Skip to main content

gam_gpu/
blas.rs

1//! Device BLAS surface for the cudarc-backed dense kernels.
2//!
3//! The public surface here is the lowest level of the GPU dispatch stack: it
4//! takes ndarray views, copies them to a device buffer, calls a cuBLAS / kernel
5//! routine, and returns the host result. The cudarc-backed implementations
6//! always compile (cudarc dynamically loads `libcuda` at runtime via the
7//! `fallback-dynamic-loading` feature), and dispatch is gated at runtime on
8//! `super::device_runtime::GpuRuntime::global()` — when no device is probed the
9//! status enum advertises `CudaUnavailable` and callers fall back to CPU.
10//!
11//! The implementations route through `super::device_runtime::cuda_context_for` and
12//! the cudarc 0.19 cuBLAS API. Any transient backend failure (OOM, launch
13//! error, …) is converted to `None` so the auto-dispatch shim in
14//! `super::linalg` falls back to the CPU fast path without disturbing
15//! numerics.
16
/// Report the CUDA backend status as seen by the BLAS surface.
///
/// Pure forwarder to `super::cuda_backend_status()`; per the module docs,
/// callers use the returned status to decide whether to dispatch to the GPU
/// kernels in this file or stay on the CPU path.
pub fn blas_backend_status() -> super::CudaBackendStatus {
    super::cuda_backend_status()
}
20
21#[cfg(target_os = "linux")]
22mod cuda_impl {
23    use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, Axis};
24
25    use crate::driver::{from_col_major, to_col_major, to_i32};
26
27    use super::super::device_runtime::GpuRuntime;
28    use cudarc::cublas::sys::{
29        cublasDiagType_t, cublasFillMode_t, cublasOperation_t, cublasSideMode_t, cublasStatus_t,
30    };
31    use cudarc::cublas::{CudaBlas, Gemm, GemmConfig, Gemv, GemvConfig, StridedBatchedConfig};
32    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
33    use cudarc::driver::{CudaSlice, CudaStream, DevicePtr, DevicePtrMut};
34    use std::sync::Arc;
35
36    /// Create a fresh stream + cuBLAS handle bound to a specific device
37    /// ordinal. This is the per-ordinal entry point used by multi-GPU fan-out
38    /// (`super::super::pool::scatter_batched` workers): the worker thread has
39    /// already bound that ordinal's context, and the stream/handle created here
40    /// target the same device. The single-device helper below is the
41    /// primary-ordinal specialization.
42    #[inline]
43    pub(crate) fn stream_and_blas_for(ordinal: usize) -> Option<(Arc<CudaStream>, CudaBlas)> {
44        let stream = super::super::device_runtime::cuda_context_for(ordinal)?
45            .new_stream()
46            .ok()?;
47        let blas = CudaBlas::new(stream.clone()).ok()?;
48        Some((stream, blas))
49    }
50
    /// Single-device convenience wrapper: creates the stream + cuBLAS handle
    /// on the ordinal stored in `runtime.device.ordinal` by forwarding to
    /// [`stream_and_blas_for`]. Returns `None` whenever that call does.
    #[inline]
    fn stream_and_blas(runtime: &GpuRuntime) -> Option<(Arc<CudaStream>, CudaBlas)> {
        stream_and_blas_for(runtime.device.ordinal)
    }
55
56    #[inline]
57    fn vector_values(v: ArrayView1<'_, f64>) -> Vec<f64> {
58        v.iter().copied().collect()
59    }
60
61    #[inline]
62    fn to_col_major_batch(batch: ArrayView3<'_, f64>) -> Vec<f64> {
63        let (batch_len, rows, cols) = batch.dim();
64        let mut out = Vec::with_capacity(batch_len.saturating_mul(rows).saturating_mul(cols));
65        for matrix in batch.axis_iter(Axis(0)) {
66            out.extend(to_col_major(&matrix).iter().copied());
67        }
68        out
69    }
70
71    #[inline]
72    fn from_col_major_batch(
73        data: &[f64],
74        batch: usize,
75        rows: usize,
76        cols: usize,
77    ) -> Option<Array3<f64>> {
78        if data.len() != batch.checked_mul(rows)?.checked_mul(cols)? {
79            return None;
80        }
81        let mut out = Array3::<f64>::zeros((batch, rows, cols));
82        let matrix_len = rows.checked_mul(cols)?;
83        for batch_idx in 0..batch {
84            let base = batch_idx.checked_mul(matrix_len)?;
85            for col in 0..cols {
86                for row in 0..rows {
87                    out[[batch_idx, row, col]] = data[base + col * rows + row];
88                }
89            }
90        }
91        Some(out)
92    }
93
94    #[inline]
95    fn row_scale_device(
96        blas: &CudaBlas,
97        stream: &Arc<CudaStream>,
98        matrix_dev: &CudaSlice<f64>,
99        weights_dev: &CudaSlice<f64>,
100        scaled_dev: &mut CudaSlice<f64>,
101        rows: usize,
102        cols: usize,
103    ) -> Option<()> {
104        let rows_i = to_i32(rows)?;
105        let cols_i = to_i32(cols)?;
106        let handle = *blas.handle();
107        let (matrix_ptr, _matrix_record) = matrix_dev.device_ptr(stream);
108        let (weights_ptr, _weights_record) = weights_dev.device_ptr(stream);
109        let (scaled_ptr, _scaled_record) = scaled_dev.device_ptr_mut(stream);
110        // SAFETY: all device slices are on this stream/context. `matrix_dev`
111        // and `scaled_dev` are rows×cols column-major matrices with lda/ldc
112        // equal to rows; `weights_dev` has one contiguous value per row.
113        let status = unsafe {
114            cudarc::cublas::sys::cublasDdgmm(
115                handle,
116                cublasSideMode_t::CUBLAS_SIDE_LEFT,
117                rows_i,
118                cols_i,
119                matrix_ptr as *const f64,
120                rows_i,
121                weights_ptr as *const f64,
122                1,
123                scaled_ptr as *mut f64,
124                rows_i,
125            )
126        };
127        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
128            Some(())
129        } else {
130            None
131        }
132    }
133
    /// Compute `leftᵀ · diag(weights) · right` on the runtime's device.
    ///
    /// Forwards to [`weighted_crossprod_for`] with `runtime.device.ordinal`;
    /// shape validation and the `None`-on-failure contract are those of that
    /// function.
    #[inline]
    fn weighted_crossprod(
        runtime: &GpuRuntime,
        left: ArrayView2<'_, f64>,
        weights: ArrayView1<'_, f64>,
        right: ArrayView2<'_, f64>,
    ) -> Option<Array2<f64>> {
        weighted_crossprod_for(runtime.device.ordinal, left, weights, right)
    }
143
144    #[inline]
145    fn weighted_crossprod_for(
146        ordinal: usize,
147        left: ArrayView2<'_, f64>,
148        weights: ArrayView1<'_, f64>,
149        right: ArrayView2<'_, f64>,
150    ) -> Option<Array2<f64>> {
151        let (rows, left_cols) = left.dim();
152        let (right_rows, right_cols) = right.dim();
153        if rows == 0
154            || left_cols == 0
155            || right_cols == 0
156            || rows != right_rows
157            || rows != weights.len()
158        {
159            return None;
160        }
161
162        let (stream, blas) = stream_and_blas_for(ordinal)?;
163        // #1412: the symmetric Gram `Xᵀ·diag(w)·X` (xt_diag_x) passes the SAME
164        // array as `left` and `right`. Detect that (identical data pointer +
165        // shape) and stage `X` ONCE instead of column-majoring and H2D-uploading
166        // two byte-identical n×p copies — halving the dominant H2D for the Gram.
167        // The GEMM operands are unchanged (`left_dev` doubles as the row-scale
168        // source), so the result is bit-identical to the two-upload path.
169        let same_operand = std::ptr::eq(left.as_ptr(), right.as_ptr())
170            && left.dim() == right.dim()
171            && left.strides() == right.strides();
172        let left_col = to_col_major(&left);
173        let weights_host = vector_values(weights);
174        let left_dev = stream.clone_htod(&*left_col).ok()?;
175        // Symmetric Gram: `right` IS `left`, so row-scale directly from the
176        // single resident `left_dev` and never upload a second n×p copy. The
177        // asymmetric path uploads `right` as before.
178        let right_dev = if same_operand {
179            None
180        } else {
181            let right_col = to_col_major(&right);
182            Some(stream.clone_htod(&*right_col).ok()?)
183        };
184        let weights_dev = stream.clone_htod(&weights_host).ok()?;
185        let mut weighted_right_dev = stream
186            .alloc_zeros::<f64>(rows.checked_mul(right_cols)?)
187            .ok()?;
188        row_scale_device(
189            &blas,
190            &stream,
191            right_dev.as_ref().unwrap_or(&left_dev),
192            &weights_dev,
193            &mut weighted_right_dev,
194            rows,
195            right_cols,
196        )?;
197
198        let mut out_dev = stream
199            .alloc_zeros::<f64>(left_cols.checked_mul(right_cols)?)
200            .ok()?;
201        let cfg = GemmConfig::<f64> {
202            transa: cublasOperation_t::CUBLAS_OP_T,
203            transb: cublasOperation_t::CUBLAS_OP_N,
204            m: to_i32(left_cols)?,
205            n: to_i32(right_cols)?,
206            k: to_i32(rows)?,
207            alpha: 1.0,
208            lda: to_i32(rows)?,
209            ldb: to_i32(rows)?,
210            beta: 0.0,
211            ldc: to_i32(left_cols)?,
212        };
213        // SAFETY: cfg computes leftᵀ (left_cols×rows) times weighted_right
214        // (rows×right_cols) into a left_cols×right_cols column-major output.
215        unsafe { blas.gemm(cfg, &left_dev, &weighted_right_dev, &mut out_dev) }.ok()?;
216        let out_col = stream.clone_dtoh(&out_dev).ok()?;
217        from_col_major(&out_col, left_cols, right_cols)
218    }
219
    /// #1017 Phase 3: a device-resident design matrix `X` whose `n×p` values are
    /// uploaded to the device ONCE and reused across many `Xᵀ·diag(w)·X` Gram
    /// evaluations.
    ///
    /// The per-call [`xt_diag_x_cuda`] path column-majors and re-uploads the
    /// full `n×p` `X` on EVERY call (a single copy since #1412, which stages
    /// the shared left/right operand once rather than twice). For the SAE /
    /// IRLS inner loop — where `X` is frozen across weight updates and the Gram
    /// is rebuilt once per Newton/PIRLS step — that H2D staging dominates the
    /// wall clock (measured #1412: the `XtWX` GEMM is ~98% of the pipeline at
    /// <20% GPU utilisation, i.e. the device is starved by the per-call upload,
    /// not the arithmetic). Uploading `X` once and crossing only the `n`-vector
    /// `w` (and the `p×p` result) per call removes that ping-pong: the resident
    /// `X` is `n·p` doubles vs the per-call `w` of `n` doubles, so the
    /// amortised transfer per Gram drops by a factor of `p`.
    pub(crate) struct ResidentWeightedGram {
        // Stream the resident buffer was uploaded on; every later upload,
        // allocation and download of this object goes through it.
        stream: Arc<CudaStream>,
        // cuBLAS handle created together with `stream`.
        blas: CudaBlas,
        // The design matrix `X`, stored column-major (`rows × cols`).
        x_dev: CudaSlice<f64>,
        // Number of rows `n` of `X` (also the required length of `w`).
        rows: usize,
        // Number of columns `p` of `X` (the Gram is `p × p`).
        cols: usize,
    }
241
242    impl ResidentWeightedGram {
243        /// Upload `x` (`n×p`) to `ordinal` once, column-major, and keep it
244        /// resident. Returns `None` on a degenerate shape or any device failure
245        /// (the caller falls back to the per-call CPU/GPU path).
246        pub(crate) fn new(ordinal: usize, x: ArrayView2<'_, f64>) -> Option<Self> {
247            let (rows, cols) = x.dim();
248            if rows == 0 || cols == 0 {
249                return None;
250            }
251            let (stream, blas) = stream_and_blas_for(ordinal)?;
252            let x_col = to_col_major(&x);
253            let x_dev = stream.clone_htod(&*x_col).ok()?;
254            Some(Self {
255                stream,
256                blas,
257                x_dev,
258                rows,
259                cols,
260            })
261        }
262
        /// Shape of the resident design matrix as `(rows, cols)`, i.e. the
        /// `(n, p)` recorded when it was uploaded.
        #[inline]
        pub(crate) fn dims(&self) -> (usize, usize) {
            (self.rows, self.cols)
        }
267
268        /// Compute `Xᵀ·diag(w)·X` reusing the resident `X`. Only `w` (`n`
269        /// doubles) crosses H2D and only the `p×p` Gram crosses D2H. The
270        /// arithmetic is bit-identical to [`xt_diag_x_cuda`] on the same device
271        /// (same `cublasDdgmm` row-scale + same `gemm` reduction order).
272        pub(crate) fn gram(&self, w: ArrayView1<'_, f64>) -> Option<Array2<f64>> {
273            if w.len() != self.rows {
274                return None;
275            }
276            let weights_host = vector_values(w);
277            let weights_dev = self.stream.clone_htod(&weights_host).ok()?;
278            let mut weighted_dev = self
279                .stream
280                .alloc_zeros::<f64>(self.rows.checked_mul(self.cols)?)
281                .ok()?;
282            row_scale_device(
283                &self.blas,
284                &self.stream,
285                &self.x_dev,
286                &weights_dev,
287                &mut weighted_dev,
288                self.rows,
289                self.cols,
290            )?;
291            let mut out_dev = self
292                .stream
293                .alloc_zeros::<f64>(self.cols.checked_mul(self.cols)?)
294                .ok()?;
295            let cfg = GemmConfig::<f64> {
296                transa: cublasOperation_t::CUBLAS_OP_T,
297                transb: cublasOperation_t::CUBLAS_OP_N,
298                m: to_i32(self.cols)?,
299                n: to_i32(self.cols)?,
300                k: to_i32(self.rows)?,
301                alpha: 1.0,
302                lda: to_i32(self.rows)?,
303                ldb: to_i32(self.rows)?,
304                beta: 0.0,
305                ldc: to_i32(self.cols)?,
306            };
307            // SAFETY: `x_dev` is the resident n×p column-major design; cfg forms
308            // Xᵀ (p×n) · weighted (n×p) → a p×p column-major Gram.
309            unsafe {
310                self.blas
311                    .gemm(cfg, &self.x_dev, &weighted_dev, &mut out_dev)
312            }
313            .ok()?;
314            let out_col = self.stream.clone_dtoh(&out_dev).ok()?;
315            from_col_major(&out_col, self.cols, self.cols)
316        }
317
318        /// Compute the resident weighted Gram `G = Xᵀ·diag(w)·X + ridge·I`,
319        /// factor it (cuSOLVER POTRF), and solve `G·β = rhs` — keeping `G`, its
320        /// Cholesky factor, and the RHS all DEVICE-RESIDENT. Only `w` (`n`),
321        /// `rhs` (`p`), and the result `β` (`p`) cross the PCIe boundary; the
322        /// `p×p` Gram is NEVER downloaded.
323        ///
324        /// This is the #1017 Phase-3 ceiling fix for the normal-equations solve:
325        /// the per-call [`gram`] still pays a `p×p` D2H (134 MB at p=4096 — the
326        /// next bottleneck once `X` is resident), whereas the SAE/IRLS inner step
327        /// only needs the `p`-vector `β = (XᵀWX+λ)⁻¹ XᵀWz`. Chaining
328        /// row-scale→GEMM→POTRF→TRSM on-device and returning only `β` removes the
329        /// Gram transfer entirely.
330        ///
331        /// `ridge` (e.g. the penalty diagonal `λ` or a Tikhonov floor) is seeded
332        /// as `ridge·I` on the device and the Gram is GEMM-accumulated onto it
333        /// (`beta = 1`), so the diagonal bump never costs a Gram round-trip.
334        /// Returns `None` on shape mismatch, a non-PD factorisation, or any
335        /// device failure (the caller falls back to the CPU solve).
336        pub(crate) fn solve_psd_normal_equations(
337            &self,
338            w: ArrayView1<'_, f64>,
339            rhs: ArrayView1<'_, f64>,
340            ridge: f64,
341        ) -> Option<Array1<f64>> {
342            if w.len() != self.rows || rhs.len() != self.cols {
343                return None;
344            }
345            let p = self.cols;
346
347            // weighted = diag(w) · X  (resident X row-scaled).
348            let weights_dev = self.stream.clone_htod(&vector_values(w)).ok()?;
349            let mut weighted_dev = self
350                .stream
351                .alloc_zeros::<f64>(self.rows.checked_mul(p)?)
352                .ok()?;
353            row_scale_device(
354                &self.blas,
355                &self.stream,
356                &self.x_dev,
357                &weights_dev,
358                &mut weighted_dev,
359                self.rows,
360                p,
361            )?;
362
363            // Pre-seed G with `ridge·I` on the device, then GEMM-accumulate
364            // `XᵀW X` onto it with `beta = 1.0`. The Gram is formed and stays
365            // device-resident: `ridge·I` is a one-time H2D upload (the only way
366            // to set a diagonal without an NVRTC kernel), and crucially the p×p
367            // Gram is NEVER read back — only `β` returns. `ridge·I` upload is
368            // bandwidth-trivial vs the avoided per-solve Gram download.
369            let mut ridge_init = vec![0.0_f64; p.checked_mul(p)?];
370            for i in 0..p {
371                ridge_init[i * p + i] = ridge;
372            }
373            let mut g_dev = self.stream.clone_htod(&ridge_init).ok()?;
374            let cfg = GemmConfig::<f64> {
375                transa: cublasOperation_t::CUBLAS_OP_T,
376                transb: cublasOperation_t::CUBLAS_OP_N,
377                m: to_i32(p)?,
378                n: to_i32(p)?,
379                k: to_i32(self.rows)?,
380                alpha: 1.0,
381                lda: to_i32(self.rows)?,
382                ldb: to_i32(self.rows)?,
383                // Accumulate onto the resident ridge·I seed.
384                beta: 1.0,
385                ldc: to_i32(p)?,
386            };
387            // SAFETY: resident n×p X and the n×p weighted buffer form a p×p Gram;
388            // beta=1 accumulates Xᵀ(WX) onto the resident ridge·I in g_dev.
389            unsafe { self.blas.gemm(cfg, &self.x_dev, &weighted_dev, &mut g_dev) }.ok()?;
390
391            // POTRF(G) → lower factor L, resident in g_dev.
392            let solver = DnHandle::new(self.stream.clone()).ok()?;
393            let info = potrf_single_dev(&solver, &self.stream, p, &mut g_dev)?;
394            if info != 0 {
395                // Not positive-definite at pivot `info`; caller falls back.
396                return None;
397            }
398
399            // Solve L Lᵀ β = rhs via two triangular solves, β resident in rhs_dev.
400            let mut rhs_dev = self.stream.clone_htod(&vector_values(rhs)).ok()?;
401            trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, false)?; // L y = rhs
402            trsm_single_vec(&self.blas, &self.stream, p, &g_dev, &mut rhs_dev, true)?; // Lᵀ β = y
403
404            // Download ONLY the p-vector solution.
405            let beta_host = self.stream.clone_dtoh(&rhs_dev).ok()?;
406            Some(Array1::from_vec(beta_host))
407        }
408    }
409
410    /// Single cuSOLVER `DPOTRF` (lower) of a resident `p×p` column-major matrix,
411    /// factored in place. Returns the cuSOLVER `info` (0 = success, k>0 = the
412    /// leading minor of order k is not PD). Mirrors the arrow-Schur frame POTRF.
413    fn potrf_single_dev(
414        solver: &DnHandle,
415        stream: &Arc<CudaStream>,
416        p: usize,
417        matrix: &mut CudaSlice<f64>,
418    ) -> Option<i32> {
419        let p_i = to_i32(p)?;
420        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
421        let mut lwork = 0_i32;
422        {
423            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
424            // SAFETY: buffer-size query against a live p×p column-major matrix.
425            let status = unsafe {
426                cusolver_sys::cusolverDnDpotrf_bufferSize(
427                    solver.cu(),
428                    uplo,
429                    p_i,
430                    mat_ptr as *mut f64,
431                    p_i,
432                    &mut lwork,
433                )
434            };
435            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
436                return None;
437            }
438        }
439        let mut workspace = stream.alloc_zeros::<f64>(lwork.max(1) as usize).ok()?;
440        let mut info_dev = stream.alloc_zeros::<i32>(1).ok()?;
441        {
442            let (mat_ptr, _rec) = matrix.device_ptr_mut(stream);
443            let (work_ptr, _wrec) = workspace.device_ptr_mut(stream);
444            let (info_ptr, _irec) = info_dev.device_ptr_mut(stream);
445            // SAFETY: all buffers live on this stream; matrix is p×p column-major.
446            let status = unsafe {
447                cusolver_sys::cusolverDnDpotrf(
448                    solver.cu(),
449                    uplo,
450                    p_i,
451                    mat_ptr as *mut f64,
452                    p_i,
453                    work_ptr as *mut f64,
454                    lwork,
455                    info_ptr as *mut i32,
456                )
457            };
458            if status != cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
459                return None;
460            }
461        }
462        let info_host = stream.clone_dtoh(&info_dev).ok()?;
463        info_host.first().copied()
464    }
465
466    /// Triangular solve `op(L)·x = b` for a single `p`-vector RHS against a
467    /// resident lower Cholesky factor `L` (`p×p` column-major), in place over
468    /// `rhs`. `transposed` selects `Lᵀ` (the second back-substitution).
469    fn trsm_single_vec(
470        blas: &CudaBlas,
471        stream: &Arc<CudaStream>,
472        p: usize,
473        l: &CudaSlice<f64>,
474        rhs: &mut CudaSlice<f64>,
475        transposed: bool,
476    ) -> Option<()> {
477        let alpha = 1.0_f64;
478        let p_i = to_i32(p)?;
479        let handle = *blas.handle();
480        let (l_ptr, _l_rec) = l.device_ptr(stream);
481        let (rhs_ptr, _rhs_rec) = rhs.device_ptr_mut(stream);
482        // SAFETY: p×p lower factor and a single p-vector RHS, both resident.
483        let status = unsafe {
484            cudarc::cublas::sys::cublasDtrsm_v2(
485                handle,
486                cublasSideMode_t::CUBLAS_SIDE_LEFT,
487                cublasFillMode_t::CUBLAS_FILL_MODE_LOWER,
488                if transposed {
489                    cublasOperation_t::CUBLAS_OP_T
490                } else {
491                    cublasOperation_t::CUBLAS_OP_N
492                },
493                cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
494                p_i,
495                1,
496                &alpha,
497                l_ptr as *const f64,
498                p_i,
499                rhs_ptr as *mut f64,
500                p_i,
501            )
502        };
503        if status == cublasStatus_t::CUBLAS_STATUS_SUCCESS {
504            Some(())
505        } else {
506            None
507        }
508    }
509
510    #[inline]
511    fn assign_block(
512        out: &mut Array2<f64>,
513        row_offset: usize,
514        col_offset: usize,
515        block: &Array2<f64>,
516    ) {
517        let (rows, cols) = block.dim();
518        for col in 0..cols {
519            for row in 0..rows {
520                out[[row_offset + row, col_offset + col]] = block[[row, col]];
521            }
522        }
523    }
524
525    #[inline]
526    fn mirror_upper_to_lower(out: &mut Array2<f64>) {
527        let n = out.nrows();
528        for row in 0..n {
529            for col in 0..row {
530                out[[row, col]] = out[[col, row]];
531            }
532        }
533    }
534
    /// Dense GEMM `op(a) · op(b)` on the runtime's device, where `trans_a` /
    /// `trans_b` select whether each operand is transposed.
    ///
    /// Forwards to [`gemm_on_ordinal_cuda`] with `runtime.device.ordinal`;
    /// shape validation and the `None`-on-failure contract are those of that
    /// function.
    #[inline]
    pub(crate) fn gemm_cuda(
        runtime: &GpuRuntime,
        a: ArrayView2<'_, f64>,
        b: ArrayView2<'_, f64>,
        trans_a: bool,
        trans_b: bool,
    ) -> Option<Array2<f64>> {
        gemm_on_ordinal_cuda(runtime.device.ordinal, a, b, trans_a, trans_b)
    }
545
546    /// Dense GEMM (optionally transposing either operand) on a specific device
547    /// ordinal. The ordinal's context is expected to be bound on the calling
548    /// thread (pool-tiled callers via `super::super::pool::scatter_batched`, or
549    /// the single-device dispatcher through [`gemm_cuda`]). Semantics are
550    /// identical to [`gemm_cuda`]; only the target device differs.
551    #[inline]
552    pub(crate) fn gemm_on_ordinal_cuda(
553        ordinal: usize,
554        a: ArrayView2<'_, f64>,
555        b: ArrayView2<'_, f64>,
556        trans_a: bool,
557        trans_b: bool,
558    ) -> Option<Array2<f64>> {
559        let (a_rows, a_cols) = a.dim();
560        let (b_rows, b_cols) = b.dim();
561        let (m, k_a) = if trans_a {
562            (a_cols, a_rows)
563        } else {
564            (a_rows, a_cols)
565        };
566        let (k_b, n) = if trans_b {
567            (b_cols, b_rows)
568        } else {
569            (b_rows, b_cols)
570        };
571        if m == 0 || n == 0 || k_a == 0 || k_a != k_b {
572            return None;
573        }
574        let (stream, blas) = stream_and_blas_for(ordinal)?;
575        let a_col = to_col_major(&a);
576        let b_col = to_col_major(&b);
577        let a_dev = stream.clone_htod(&*a_col).ok()?;
578        let b_dev = stream.clone_htod(&*b_col).ok()?;
579        let mut out_dev = stream.alloc_zeros::<f64>(m.checked_mul(n)?).ok()?;
580        let cfg = GemmConfig::<f64> {
581            transa: if trans_a {
582                cublasOperation_t::CUBLAS_OP_T
583            } else {
584                cublasOperation_t::CUBLAS_OP_N
585            },
586            transb: if trans_b {
587                cublasOperation_t::CUBLAS_OP_T
588            } else {
589                cublasOperation_t::CUBLAS_OP_N
590            },
591            m: to_i32(m)?,
592            n: to_i32(n)?,
593            k: to_i32(k_a)?,
594            alpha: 1.0,
595            lda: to_i32(a_rows)?,
596            ldb: to_i32(b_rows)?,
597            beta: 0.0,
598            ldc: to_i32(m)?,
599        };
600        // SAFETY: buffers are column-major with dimensions validated above.
601        unsafe { blas.gemm(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
602        let out_col = stream.clone_dtoh(&out_dev).ok()?;
603        from_col_major(&out_col, m, n)
604    }
605
606    /// Broadcast-B batched GEMM on a specific device ordinal. The caller
607    /// (`super::super::pool::scatter_batched` worker, or the single-device
608    /// dispatcher) supplies the ordinal whose context is already bound on this
609    /// thread; the stream/handle are created on that same device.
610    #[inline]
611    pub(crate) fn gemm_broadcast_b_batched_cuda(
612        ordinal: usize,
613        a: ArrayView3<'_, f64>,
614        b: ArrayView2<'_, f64>,
615    ) -> Option<Array3<f64>> {
616        let (batch, m, k) = a.dim();
617        let (b_rows, n) = b.dim();
618        if batch == 0 || m == 0 || n == 0 || k == 0 || b_rows != k {
619            return None;
620        }
621        let (stream, blas) = stream_and_blas_for(ordinal)?;
622        let a_col = to_col_major_batch(a);
623        let b_col = to_col_major(&b);
624        let a_dev = stream.clone_htod(&a_col).ok()?;
625        let b_dev = stream.clone_htod(&*b_col).ok()?;
626        let mut out_dev = stream
627            .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
628            .ok()?;
629        let cfg = StridedBatchedConfig::<f64> {
630            gemm: GemmConfig::<f64> {
631                transa: cublasOperation_t::CUBLAS_OP_N,
632                transb: cublasOperation_t::CUBLAS_OP_N,
633                m: to_i32(m)?,
634                n: to_i32(n)?,
635                k: to_i32(k)?,
636                alpha: 1.0,
637                lda: to_i32(m)?,
638                ldb: to_i32(k)?,
639                beta: 0.0,
640                ldc: to_i32(m)?,
641            },
642            batch_size: to_i32(batch)?,
643            stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
644            stride_b: 0,
645            stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
646        };
647        // SAFETY: `a_dev` is a stack of batch column-major m×k matrices,
648        // `b_dev` is one shared column-major k×n matrix with zero batch stride,
649        // and `out_dev` is a stack of batch column-major m×n outputs.
650        unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
651        let out_col = stream.clone_dtoh(&out_dev).ok()?;
652        from_col_major_batch(&out_col, batch, m, n)
653    }
654
655    /// A·Bᵀ strided-batched GEMM on a specific device ordinal. As with the
656    /// broadcast variant, the ordinal's context is expected to be bound on the
657    /// calling thread (multi-GPU worker or single-device dispatcher).
658    #[inline]
659    pub(crate) fn gemm_abt_strided_batched_cuda(
660        ordinal: usize,
661        a: ArrayView3<'_, f64>,
662        b: ArrayView3<'_, f64>,
663    ) -> Option<Array3<f64>> {
664        let (batch, m, k) = a.dim();
665        let (batch_b, n, k_b) = b.dim();
666        if batch == 0 || m == 0 || n == 0 || k == 0 || batch != batch_b || k != k_b {
667            return None;
668        }
669        let (stream, blas) = stream_and_blas_for(ordinal)?;
670        let a_col = to_col_major_batch(a);
671        let b_col = to_col_major_batch(b);
672        let a_dev = stream.clone_htod(&a_col).ok()?;
673        let b_dev = stream.clone_htod(&b_col).ok()?;
674        let mut out_dev = stream
675            .alloc_zeros::<f64>(batch.checked_mul(m)?.checked_mul(n)?)
676            .ok()?;
677        let cfg = StridedBatchedConfig::<f64> {
678            gemm: GemmConfig::<f64> {
679                transa: cublasOperation_t::CUBLAS_OP_N,
680                transb: cublasOperation_t::CUBLAS_OP_T,
681                m: to_i32(m)?,
682                n: to_i32(n)?,
683                k: to_i32(k)?,
684                alpha: 1.0,
685                lda: to_i32(m)?,
686                ldb: to_i32(n)?,
687                beta: 0.0,
688                ldc: to_i32(m)?,
689            },
690            batch_size: to_i32(batch)?,
691            stride_a: i64::try_from(m.checked_mul(k)?).ok()?,
692            stride_b: i64::try_from(n.checked_mul(k)?).ok()?,
693            stride_c: i64::try_from(m.checked_mul(n)?).ok()?,
694        };
695        // SAFETY: each batch item is column-major. The B batch stores n×k
696        // matrices and cuBLAS transposes each to k×n before multiplication.
697        unsafe { blas.gemm_strided_batched(cfg, &a_dev, &b_dev, &mut out_dev) }.ok()?;
698        let out_col = stream.clone_dtoh(&out_dev).ok()?;
699        from_col_major_batch(&out_col, batch, m, n)
700    }
701
702    #[inline]
703    pub(crate) fn gemv_cuda(
704        runtime: &GpuRuntime,
705        a: ArrayView2<'_, f64>,
706        v: ArrayView1<'_, f64>,
707        trans_a: bool,
708    ) -> Option<Array1<f64>> {
709        let (rows, cols) = a.dim();
710        let out_len = if trans_a { cols } else { rows };
711        let needed = if trans_a { rows } else { cols };
712        if out_len == 0 || needed == 0 || v.len() != needed {
713            return None;
714        }
715        let (stream, blas) = stream_and_blas(runtime)?;
716        let a_col = to_col_major(&a);
717        let a_dev = stream.clone_htod(&*a_col).ok()?;
718        let v_host = vector_values(v);
719        let v_dev = stream.clone_htod(&v_host).ok()?;
720        let mut out_dev = stream.alloc_zeros::<f64>(out_len).ok()?;
721        let cfg = GemvConfig::<f64> {
722            trans: if trans_a {
723                cublasOperation_t::CUBLAS_OP_T
724            } else {
725                cublasOperation_t::CUBLAS_OP_N
726            },
727            m: to_i32(rows)?,
728            n: to_i32(cols)?,
729            alpha: 1.0,
730            lda: to_i32(rows)?,
731            incx: 1,
732            beta: 0.0,
733            incy: 1,
734        };
735        // SAFETY: dimensions and vector length match the cuBLAS GEMV contract.
736        unsafe { blas.gemv(cfg, &a_dev, &v_dev, &mut out_dev) }.ok()?;
737        Some(Array1::from_vec(stream.clone_dtoh(&out_dev).ok()?))
738    }
739
740    #[inline]
741    pub fn xt_diag_x_cuda(
742        runtime: &GpuRuntime,
743        x: ArrayView2<'_, f64>,
744        w: ArrayView1<'_, f64>,
745    ) -> Option<Array2<f64>> {
746        let (rows, cols) = x.dim();
747        if rows == 0 || cols == 0 || rows != w.len() {
748            return None;
749        }
750        weighted_crossprod(runtime, x, w, x)
751    }
752
753    #[inline]
754    pub(crate) fn xt_diag_x_on_ordinal_cuda(
755        ordinal: usize,
756        x: ArrayView2<'_, f64>,
757        w: ArrayView1<'_, f64>,
758    ) -> Option<Array2<f64>> {
759        let (rows, cols) = x.dim();
760        if rows == 0 || cols == 0 || rows != w.len() {
761            return None;
762        }
763        weighted_crossprod_for(ordinal, x, w, x)
764    }
765
766    #[inline]
767    pub fn xt_diag_y_cuda(
768        runtime: &GpuRuntime,
769        x: ArrayView2<'_, f64>,
770        w: ArrayView1<'_, f64>,
771        y: ArrayView2<'_, f64>,
772    ) -> Option<Array2<f64>> {
773        weighted_crossprod(runtime, x, w, y)
774    }
775
776    #[inline]
777    pub(crate) fn joint_hessian_2x2_cuda(
778        runtime: &GpuRuntime,
779        x_a: ArrayView2<'_, f64>,
780        x_b: ArrayView2<'_, f64>,
781        w_aa: ArrayView1<'_, f64>,
782        w_ab: ArrayView1<'_, f64>,
783        w_bb: ArrayView1<'_, f64>,
784    ) -> Option<Array2<f64>> {
785        let (rows, pa) = x_a.dim();
786        let (rows_b, pb) = x_b.dim();
787        let total = pa.checked_add(pb)?;
788        if rows == 0
789            || total == 0
790            || rows != rows_b
791            || rows != w_aa.len()
792            || rows != w_ab.len()
793            || rows != w_bb.len()
794        {
795            return None;
796        }
797
798        let mut out = Array2::<f64>::zeros((total, total));
799        if pa > 0 {
800            let aa = weighted_crossprod(runtime, x_a, w_aa, x_a)?;
801            assign_block(&mut out, 0, 0, &aa);
802        }
803        if pa > 0 && pb > 0 {
804            let ab = weighted_crossprod(runtime, x_a, w_ab, x_b)?;
805            assign_block(&mut out, 0, pa, &ab);
806        }
807        if pb > 0 {
808            let bb = weighted_crossprod(runtime, x_b, w_bb, x_b)?;
809            assign_block(&mut out, pa, pa, &bb);
810        }
811        mirror_upper_to_lower(&mut out);
812        Some(out)
813    }
814
815    #[inline]
816    pub(crate) fn trsm_cuda(
817        runtime: &GpuRuntime,
818        triangular: ArrayView2<'_, f64>,
819        rhs: ArrayView2<'_, f64>,
820        upper: bool,
821    ) -> Option<Array2<f64>> {
822        let (n, n2) = triangular.dim();
823        if n == 0 || n != n2 || rhs.nrows() != n {
824            return None;
825        }
826        let nrhs = rhs.ncols();
827        let (stream, blas) = stream_and_blas(runtime)?;
828        let tri_col = to_col_major(&triangular);
829        let rhs_col = to_col_major(&rhs);
830        let tri_dev = stream.clone_htod(&*tri_col).ok()?;
831        let mut rhs_dev = stream.clone_htod(&*rhs_col).ok()?;
832        let alpha = 1.0_f64;
833        let handle = *blas.handle();
834        {
835            let (tri_ptr, _tri_record) = tri_dev.device_ptr(&stream);
836            let (rhs_ptr, _rhs_record) = rhs_dev.device_ptr_mut(&stream);
837            // SAFETY: triangular is n×n and rhs is n×nrhs in column-major device
838            // buffers. cublasDtrsm overwrites rhs with A^{-1} rhs.
839            let status = unsafe {
840                cudarc::cublas::sys::cublasDtrsm_v2(
841                    handle,
842                    cublasSideMode_t::CUBLAS_SIDE_LEFT,
843                    if upper {
844                        cublasFillMode_t::CUBLAS_FILL_MODE_UPPER
845                    } else {
846                        cublasFillMode_t::CUBLAS_FILL_MODE_LOWER
847                    },
848                    cublasOperation_t::CUBLAS_OP_N,
849                    cublasDiagType_t::CUBLAS_DIAG_NON_UNIT,
850                    to_i32(n)?,
851                    to_i32(nrhs)?,
852                    &alpha,
853                    tri_ptr as *const f64,
854                    to_i32(n)?,
855                    rhs_ptr as *mut f64,
856                    to_i32(n)?,
857                )
858            };
859            if status != cublasStatus_t::CUBLAS_STATUS_SUCCESS {
860                return None;
861            }
862        };
863        let out_col = stream.clone_dtoh(&rhs_dev).ok()?;
864        from_col_major(&out_col, n, nrhs)
865    }
866}
867
868#[cfg(target_os = "linux")]
869pub(crate) use cuda_impl::{
870    ResidentWeightedGram, gemm_abt_strided_batched_cuda, gemm_broadcast_b_batched_cuda, gemm_cuda,
871    gemm_on_ordinal_cuda, gemv_cuda, joint_hessian_2x2_cuda, trsm_cuda, xt_diag_x_on_ordinal_cuda,
872};
873// Cross-crate cuBLAS entry points (gam-models BMS Hessian paths call these
874// directly): the #1521 carve promoted the sibling solver/GEMM entry points to
875// `pub` but left these two `pub(crate)`, so they were invisible to their
876// out-of-crate callers (E0603) on the linux cuda build that the workspace
877// `cargo check` config does not exercise. Promote to match the rest of the
878// cross-crate cuBLAS surface.
879#[cfg(target_os = "linux")]
880pub use cuda_impl::{xt_diag_x_cuda, xt_diag_y_cuda};