Skip to main content

gam_gpu/
solver.rs

1//! cuSOLVER-backed dense solver kernels for the GPU HAL.
2//!
3//! This module owns CUDA solver functionality that is shared by GPU linear
4//! algebra dispatch and higher-level solver code. CPU solves do not live behind
5//! these entry points: unavailable CUDA support is reported as an error.
6
7use ndarray::{Array2, ArrayView2};
8
9pub fn solver_backend_status() -> super::CudaBackendStatus {
10    super::cuda_backend_status()
11}
12
13/// Outcome reported by [`iterative_refinement_cholesky_solve`].
14#[derive(Clone, Debug)]
15pub struct RefinementOutcome {
16    /// Solution vector `x` satisfying `A x ≈ b`.
17    pub solution: ndarray::Array1<f64>,
18    /// `‖r‖ / ‖b‖` where `r = b − A x` after the last refinement step
19    /// (or after the initial fp32 solve when no steps were taken).
20    pub relative_residual: f64,
21    /// Precision path used for the factorization.
22    pub used_fp32_factor: bool,
23    /// Number of refinement steps taken (0 means only the initial solve ran).
24    pub refinement_steps: usize,
25}
26
27#[cfg(target_os = "linux")]
28mod cuda {
29    use crate::driver::{from_col_major, to_col_major};
30    use cudarc::cublas::sys as cublas_sys;
31    use cudarc::cublas::{CudaBlas, Gemv, GemvConfig};
32    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
33    use cudarc::driver::{CudaContext, CudaSlice, DevicePtr, DevicePtrMut};
34    use faer::MatRef;
35    use gam_linalg::faer_ndarray::cholesky_factor_logdet;
36    use ndarray::{Array2, ArrayView2};
37
38    pub(super) fn cholesky_solve(
39        hessian: ArrayView2<'_, f64>,
40        rhs: ArrayView2<'_, f64>,
41    ) -> Result<(Array2<f64>, f64), String> {
42        let (_, stream) = context_and_stream()?;
43        let (p, p2) = hessian.dim();
44        if p == 0 || p != p2 || rhs.nrows() != p {
45            return Err("Cholesky solve dimension mismatch".to_string());
46        }
47        let nrhs = rhs.ncols();
48        let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
49        let h_col = to_col_major(&hessian);
50        let rhs_col = to_col_major(&rhs);
51        let mut h_dev = pinned_htod(&stream, &h_col)?;
52        let mut rhs_dev = pinned_htod(&stream, &rhs_col)?;
53        potrf_in_place(&solver, &stream, p, &mut h_dev)?;
54        potrs_in_place(&solver, &stream, p, nrhs, &h_dev, &mut rhs_dev)?;
55        let factor_col = stream
56            .clone_dtoh(&h_dev)
57            .map_err(|e| format!("download Cholesky factor: {e}"))?;
58        let out_col = stream
59            .clone_dtoh(&rhs_dev)
60            .map_err(|e| format!("download solution: {e}"))?;
61        let solved =
62            from_col_major(&out_col, p, nrhs).ok_or("solution layout conversion failed")?;
63        Ok((solved, cholesky_logdet_from_col_major(&factor_col, p)))
64    }
65
66    /// fp64 log-determinant of an SPD matrix via POTRF only.
67    ///
68    /// This is [`cholesky_solve`] stripped of the triangular solve (POTRS) and
69    /// the solution download/layout conversion: the log-determinant depends
70    /// solely on the Cholesky factor's diagonal, so when a caller already holds
71    /// the solution (e.g. from fp32 + iterative refinement) and needs *only* an
72    /// accurate fp64 logdet, doing a full solve here would burn an O(p²·nrhs)
73    /// POTRS plus a host round-trip on a solution that is immediately discarded.
74    pub(super) fn cholesky_logdet(hessian: ArrayView2<'_, f64>) -> Result<f64, String> {
75        let (_, stream) = context_and_stream()?;
76        let (p, p2) = hessian.dim();
77        if p == 0 || p != p2 {
78            return Err("Cholesky logdet dimension mismatch".to_string());
79        }
80        let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
81        let h_col = to_col_major(&hessian);
82        let mut h_dev = pinned_htod(&stream, &h_col)?;
83        potrf_in_place(&solver, &stream, p, &mut h_dev)?;
84        let factor_col = stream
85            .clone_dtoh(&h_dev)
86            .map_err(|e| format!("download Cholesky factor: {e}"))?;
87        Ok(cholesky_logdet_from_col_major(&factor_col, p))
88    }
89
90    pub(super) fn cholesky_lower(hessian: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
91        let (_, stream) = context_and_stream()?;
92        cholesky_lower_on_stream(hessian, &stream)
93    }
94
95    pub(super) fn cholesky_lower_on_ordinal(
96        ordinal: usize,
97        hessian: ArrayView2<'_, f64>,
98    ) -> Result<Array2<f64>, String> {
99        let (_, stream) = context_and_stream_for(ordinal)?;
100        cholesky_lower_on_stream(hessian, &stream)
101    }
102
103    fn cholesky_lower_on_stream(
104        hessian: ArrayView2<'_, f64>,
105        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
106    ) -> Result<Array2<f64>, String> {
107        let (p, p2) = hessian.dim();
108        if p == 0 || p != p2 {
109            return Err("Cholesky factorization dimension mismatch".to_string());
110        }
111        let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
112        let h_col = to_col_major(&hessian);
113        let mut h_dev = pinned_htod(&stream, &h_col)?;
114        potrf_in_place(&solver, &stream, p, &mut h_dev)?;
115        let factor_col = stream
116            .clone_dtoh(&h_dev)
117            .map_err(|e| format!("download Cholesky factor: {e}"))?;
118        let mut lower =
119            from_col_major(&factor_col, p, p).ok_or("factor layout conversion failed")?;
120        for row in 0..p {
121            for col in (row + 1)..p {
122                lower[[row, col]] = 0.0;
123            }
124        }
125        Ok(lower)
126    }
127
128    // -----------------------------------------------------------------------
129    // Precision-generic Cholesky scaffold
130    //
131    // POTRF / POTRS host scaffolds are identical across single and double
132    // precision apart from the cuSOLVER symbol called and the device pointer
133    // type. `CholScalar` selects those per-precision pieces so the host-side
134    // allocation / info-handling / error-formatting logic lives once. The
135    // `Dpotr*` (f64) and `Spotr*` (f32) entry points below are thin wrappers
136    // over the generic helpers, preserving their public signatures byte for
137    // byte.
138    // -----------------------------------------------------------------------
139
140    /// cuSOLVER scalar abstraction: selects the precision-specific POTRF/POTRS
141    /// symbols and the precision tag used in deferred-info error messages.
142    ///
143    /// The FFI into cuSOLVER lives inside the trait methods' bodies (in `unsafe`
144    /// blocks), so the trait and its methods are safe to call: each impl wires
145    /// its method bodies to the cuSOLVER entry points whose pointer arguments
146    /// match `Self` (e.g. `cusolverDnDpotrf` for `f64`). Implementors must keep
147    /// that pairing consistent — the device pointer passed in is typed `*mut
148    /// Self` / `*const Self`, so a mismatched symbol would hand cuSOLVER a
149    /// wrongly-typed buffer.
150    pub(crate) trait CholScalar:
151        cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits + Copy
152    {
153        /// cuSOLVER `*potrf_bufferSize`: `(handle, uplo, n, A, lda, *lwork)`.
154        ///
155        /// `a` is a live `n*n` column-major device buffer of type `Self`,
156        /// `lwork` is a host out-param. The unsafe FFI call is contained in the
157        /// method body.
158        fn potrf_buffer_size(
159            handle: cusolver_sys::cusolverDnHandle_t,
160            uplo: cusolver_sys::cublasFillMode_t,
161            n: i32,
162            a: *mut Self,
163            lda: i32,
164            lwork: *mut i32,
165        ) -> cusolver_sys::cusolverStatus_t;
166        /// cuSOLVER `*potrf`: `(handle, uplo, n, A, lda, work, lwork, info)`.
167        ///
168        /// Pointer args must reference live device buffers of the documented
169        /// shape; the unsafe FFI call is contained in the method body.
170        fn potrf(
171            handle: cusolver_sys::cusolverDnHandle_t,
172            uplo: cusolver_sys::cublasFillMode_t,
173            n: i32,
174            a: *mut Self,
175            lda: i32,
176            work: *mut Self,
177            lwork: i32,
178            info: *mut i32,
179        ) -> cusolver_sys::cusolverStatus_t;
180        /// cuSOLVER `*potrs`: `(handle, uplo, n, nrhs, A, lda, B, ldb, info)`.
181        ///
182        /// Pointer args must reference live device buffers of the documented
183        /// shape; the unsafe FFI call is contained in the method body.
184        fn potrs(
185            handle: cusolver_sys::cusolverDnHandle_t,
186            uplo: cusolver_sys::cublasFillMode_t,
187            n: i32,
188            nrhs: i32,
189            a: *const Self,
190            lda: i32,
191            b: *mut Self,
192            ldb: i32,
193            info: *mut i32,
194        ) -> cusolver_sys::cusolverStatus_t;
195        /// Symbol name fragment for error messages (e.g. `"Dpotrf"`).
196        const POTRF_NAME: &'static str;
197        const POTRS_NAME: &'static str;
198        /// Trailing clause appended to a POTRF "not SPD" error (e.g.
199        /// `" (matrix not SPD at f32)"`); empty for f64.
200        const POTRF_FAIL_SUFFIX: &'static str;
201    }
202
203    impl CholScalar for f64 {
204        fn potrf_buffer_size(
205            handle: cusolver_sys::cusolverDnHandle_t,
206            uplo: cusolver_sys::cublasFillMode_t,
207            n: i32,
208            a: *mut f64,
209            lda: i32,
210            lwork: *mut i32,
211        ) -> cusolver_sys::cusolverStatus_t {
212            // SAFETY: caller guarantees `a` is a live n*n column-major f64 device
213            // buffer and `lwork` is a valid host out-param; symbol matches f64.
214            unsafe { cusolver_sys::cusolverDnDpotrf_bufferSize(handle, uplo, n, a, lda, lwork) }
215        }
216        fn potrf(
217            handle: cusolver_sys::cusolverDnHandle_t,
218            uplo: cusolver_sys::cublasFillMode_t,
219            n: i32,
220            a: *mut f64,
221            lda: i32,
222            work: *mut f64,
223            lwork: i32,
224            info: *mut i32,
225        ) -> cusolver_sys::cusolverStatus_t {
226            // SAFETY: caller guarantees `a` is a live n*n column-major f64 buffer,
227            // `work` was sized by potrf_buffer_size, `info` is a 1-element i32
228            // device buffer; symbol matches f64.
229            unsafe { cusolver_sys::cusolverDnDpotrf(handle, uplo, n, a, lda, work, lwork, info) }
230        }
231        fn potrs(
232            handle: cusolver_sys::cusolverDnHandle_t,
233            uplo: cusolver_sys::cublasFillMode_t,
234            n: i32,
235            nrhs: i32,
236            a: *const f64,
237            lda: i32,
238            b: *mut f64,
239            ldb: i32,
240            info: *mut i32,
241        ) -> cusolver_sys::cusolverStatus_t {
242            // SAFETY: caller guarantees `a` is a live n*n f64 Cholesky factor,
243            // `b` is n*nrhs column-major f64, `info` is a 1-element i32 device
244            // buffer; symbol matches f64.
245            unsafe { cusolver_sys::cusolverDnDpotrs(handle, uplo, n, nrhs, a, lda, b, ldb, info) }
246        }
247        const POTRF_NAME: &'static str = "Dpotrf";
248        const POTRS_NAME: &'static str = "Dpotrs";
249        const POTRF_FAIL_SUFFIX: &'static str = "";
250    }
251
252    impl CholScalar for f32 {
253        fn potrf_buffer_size(
254            handle: cusolver_sys::cusolverDnHandle_t,
255            uplo: cusolver_sys::cublasFillMode_t,
256            n: i32,
257            a: *mut f32,
258            lda: i32,
259            lwork: *mut i32,
260        ) -> cusolver_sys::cusolverStatus_t {
261            // SAFETY: caller guarantees `a` is a live n*n column-major f32 device
262            // buffer and `lwork` is a valid host out-param; symbol matches f32.
263            unsafe { cusolver_sys::cusolverDnSpotrf_bufferSize(handle, uplo, n, a, lda, lwork) }
264        }
265        fn potrf(
266            handle: cusolver_sys::cusolverDnHandle_t,
267            uplo: cusolver_sys::cublasFillMode_t,
268            n: i32,
269            a: *mut f32,
270            lda: i32,
271            work: *mut f32,
272            lwork: i32,
273            info: *mut i32,
274        ) -> cusolver_sys::cusolverStatus_t {
275            // SAFETY: caller guarantees `a` is a live n*n column-major f32 buffer,
276            // `work` was sized by potrf_buffer_size, `info` is a 1-element i32
277            // device buffer; symbol matches f32.
278            unsafe { cusolver_sys::cusolverDnSpotrf(handle, uplo, n, a, lda, work, lwork, info) }
279        }
280        fn potrs(
281            handle: cusolver_sys::cusolverDnHandle_t,
282            uplo: cusolver_sys::cublasFillMode_t,
283            n: i32,
284            nrhs: i32,
285            a: *const f32,
286            lda: i32,
287            b: *mut f32,
288            ldb: i32,
289            info: *mut i32,
290        ) -> cusolver_sys::cusolverStatus_t {
291            // SAFETY: caller guarantees `a` is a live n*n f32 Cholesky factor,
292            // `b` is n*nrhs column-major f32, `info` is a 1-element i32 device
293            // buffer; symbol matches f32.
294            unsafe { cusolver_sys::cusolverDnSpotrs(handle, uplo, n, nrhs, a, lda, b, ldb, info) }
295        }
296        const POTRF_NAME: &'static str = "Spotrf";
297        const POTRS_NAME: &'static str = "Spotrs";
298        const POTRF_FAIL_SUFFIX: &'static str = " (matrix not SPD at f32)";
299    }
300
301    /// Query the cuSOLVER POTRF workspace size (element count) for a p×p
302    /// matrix at precision `T`. Allocates a temporary p×p dummy buffer for the
303    /// query.
304    fn potrf_bufsize_generic<T: CholScalar>(
305        solver: &DnHandle,
306        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
307        p: usize,
308    ) -> Result<usize, String> {
309        let p_i = to_i32(p)?;
310        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
311        let mut lwork = 0_i32;
312        let mut dummy = stream
313            .alloc_zeros::<T>(p.checked_mul(p).ok_or("p² overflow in lwork query")?)
314            .map_err(|e| format!("cuda alloc dummy for lwork query: {e}"))?;
315        {
316            let (ptr, _rec) = dummy.device_ptr_mut(stream);
317            // dummy is a live p*p device buffer of type T, lwork is a host i32;
318            // the unsafe cuSOLVER FFI is contained in T::potrf_buffer_size.
319            let status =
320                T::potrf_buffer_size(solver.cu(), uplo, p_i, ptr as *mut T, p_i, &mut lwork);
321            check_cusolver(status, "cusolverDn*potrf_bufferSize")?;
322        }
323        usize::try_from(lwork).map_err(|_| "negative potrf lwork".to_string())
324    }
325
326    /// Factor a p×p SPD device buffer in-place (lower-triangular Cholesky) at
327    /// precision `T`, querying and allocating its own workspace. Returns `Err`
328    /// if the matrix is singular/indefinite at precision `T`.
329    ///
330    /// This is the single-matrix POTRF core shared across the GPU layer:
331    /// `solver.rs`'s `potrf_in_place`/`spotrf_in_place` and `linalg.rs`'s
332    /// `potrf_lower_in_place` all route through it (the latter mapping the
333    /// `Result` to its `Option` contract at the boundary). The batched POTRF
334    /// (`cusolverDnDpotrfBatched`) in `linalg.rs` is intentionally separate.
335    pub(crate) fn potrf_in_place_generic<T: CholScalar>(
336        solver: &DnHandle,
337        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
338        p: usize,
339        a: &mut CudaSlice<T>,
340    ) -> Result<(), String> {
341        let p_i = to_i32(p)?;
342        let lwork = potrf_bufsize_generic::<T>(solver, stream, p)?;
343        let lwork_i = i32::try_from(lwork).map_err(|_| "negative potrf workspace".to_string())?;
344        let mut workspace = stream
345            .alloc_zeros::<T>(lwork.max(1))
346            .map_err(|e| format!("cuda alloc potrf workspace: {e}"))?;
347        let mut info = stream
348            .alloc_zeros::<i32>(1)
349            .map_err(|e| format!("cuda alloc potrf info: {e}"))?;
350        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
351        {
352            let (a_ptr, _a_rec) = a.device_ptr_mut(stream);
353            let (work_ptr, _work_rec) = workspace.device_ptr_mut(stream);
354            let (info_ptr, _info_rec) = info.device_ptr_mut(stream);
355            // a is p*p col-major T, workspace was sized by T::potrf_buffer_size,
356            // info is a 1-element i32 device buffer; the unsafe cuSOLVER FFI is
357            // contained in T::potrf.
358            let status = T::potrf(
359                solver.cu(),
360                uplo,
361                p_i,
362                a_ptr as *mut T,
363                p_i,
364                work_ptr as *mut T,
365                lwork_i,
366                info_ptr as *mut i32,
367            );
368            check_cusolver(status, "cusolverDn*potrf")?;
369        }
370        let info_host = stream
371            .clone_dtoh(&info)
372            .map_err(|e| format!("download potrf info: {e}"))?;
373        if info_host[0] == 0 {
374            Ok(())
375        } else {
376            Err(format!(
377                "cusolverDn{} returned info={}{}",
378                T::POTRF_NAME,
379                info_host[0],
380                T::POTRF_FAIL_SUFFIX
381            ))
382        }
383    }
384
385    /// Triangular solve using a pre-factored Cholesky lower-triangle at
386    /// precision `T`. Solves `A · x = rhs` in-place into `rhs` (column-major,
387    /// p × nrhs), allocating and downloading its own info scalar.
388    fn potrs_in_place_generic<T: CholScalar>(
389        solver: &DnHandle,
390        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
391        p: usize,
392        nrhs: usize,
393        factor: &CudaSlice<T>,
394        rhs: &mut CudaSlice<T>,
395    ) -> Result<(), String> {
396        let p_i = to_i32(p)?;
397        let nrhs_i = to_i32(nrhs)?;
398        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
399        let mut info = stream
400            .alloc_zeros::<i32>(1)
401            .map_err(|e| format!("cuda alloc potrs info: {e}"))?;
402        {
403            let (f_ptr, _f_rec) = factor.device_ptr(stream);
404            let (r_ptr, _r_rec) = rhs.device_ptr_mut(stream);
405            let (info_ptr, _info_rec) = info.device_ptr_mut(stream);
406            // factor is a p*p lower-triangular T from potrf, rhs is p*nrhs
407            // col-major T, info is a 1-element i32 device buffer; leading dims
408            // match column-major p_i. The unsafe cuSOLVER FFI is contained in
409            // T::potrs.
410            let status = T::potrs(
411                solver.cu(),
412                uplo,
413                p_i,
414                nrhs_i,
415                f_ptr as *const T,
416                p_i,
417                r_ptr as *mut T,
418                p_i,
419                info_ptr as *mut i32,
420            );
421            check_cusolver(status, "cusolverDn*potrs")?;
422        }
423        let info_host = stream
424            .clone_dtoh(&info)
425            .map_err(|e| format!("download potrs info: {e}"))?;
426        if info_host[0] == 0 {
427            Ok(())
428        } else {
429            Err(format!(
430                "cusolverDn{} returned info={}",
431                T::POTRS_NAME,
432                info_host[0]
433            ))
434        }
435    }
436
437    // -----------------------------------------------------------------------
438    // fp32 entry points (thin wrappers over the precision-generic scaffold)
439    // -----------------------------------------------------------------------
440
441    /// Factor a p×p symmetric positive-definite f32 device buffer in-place
442    /// (lower-triangular Cholesky). Returns `Err` if the matrix is
443    /// singular/indefinite.
444    fn spotrf_in_place(
445        solver: &DnHandle,
446        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
447        p: usize,
448        a: &mut CudaSlice<f32>,
449    ) -> Result<(), String> {
450        potrf_in_place_generic::<f32>(solver, stream, p, a)
451    }
452
453    /// Triangular solve using a pre-factored fp32 Cholesky lower-triangle.
454    /// Solves `A · x = rhs` in-place into `rhs` (column-major, p × nrhs).
455    fn spotrs_in_place(
456        solver: &DnHandle,
457        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
458        p: usize,
459        nrhs: usize,
460        factor: &CudaSlice<f32>,
461        rhs: &mut CudaSlice<f32>,
462    ) -> Result<(), String> {
463        potrs_in_place_generic::<f32>(solver, stream, p, nrhs, factor, rhs)
464    }
465
466    // -----------------------------------------------------------------------
467    // fp64 DGEMV residual: r = b − A·x in double precision
468    // -----------------------------------------------------------------------
469
470    /// Compute `r = b − A·x` in fp64 where A is p×p and x, b, r are length p.
471    ///
472    /// Overwrites the output buffer `r_dev` with the residual. Uses
473    /// `cublasDgemv` (CUBLAS_OP_N): `r = 1·A·x + 0·0 = A·x`, then the host
474    /// subtracts from b. Because p is small here (the policy gates on p ≥ 64
475    /// and the Newton system is p×p), downloading the p-vector for the host
476    /// subtract is cheap relative to the GEMV.
477    fn residual_norm_and_vec(
478        blas: &CudaBlas,
479        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
480        p: usize,
481        a_dev: &CudaSlice<f64>,
482        x_dev: &CudaSlice<f64>,
483        b_host: &[f64],
484    ) -> Result<(Vec<f64>, f64), String> {
485        let p_i = to_i32(p)?;
486        // ax_dev = A · x
487        let mut ax_dev = stream
488            .alloc_zeros::<f64>(p)
489            .map_err(|e| format!("alloc ax: {e}"))?;
490        {
491            let cfg = GemvConfig::<f64> {
492                trans: cublas_sys::cublasOperation_t::CUBLAS_OP_N,
493                m: p_i,
494                n: p_i,
495                alpha: 1.0_f64,
496                lda: p_i,
497                incx: 1,
498                beta: 0.0_f64,
499                incy: 1,
500            };
501            // SAFETY: cuBLAS Dgemv; a_dev is p*p col-major f64, x_dev is
502            // length-p f64, ax_dev is length-p output; all on the same stream.
503            unsafe { blas.gemv(cfg, a_dev, x_dev, &mut ax_dev) }
504                .map_err(|e| format!("cublasDgemv for residual: {e}"))?;
505        }
506        let ax_host = stream
507            .clone_dtoh(&ax_dev)
508            .map_err(|e| format!("download A·x: {e}"))?;
509        // r = b − A·x  (host subtract; p is small)
510        let r: Vec<f64> = b_host
511            .iter()
512            .zip(ax_host.iter())
513            .map(|(bi, axi)| bi - axi)
514            .collect();
515        let norm_r = r.iter().map(|v| v * v).sum::<f64>().sqrt();
516        Ok((r, norm_r))
517    }
518
519    // -----------------------------------------------------------------------
520    // Iterative refinement: fp32 factor → fp32 solve → fp64 residual loop
521    // -----------------------------------------------------------------------
522
523    /// Solve `A x = b` using an fp32 Cholesky factorization with up to
524    /// `max_steps` fp64-residual iterative refinement corrections.
525    ///
526    /// # Algorithm
527    ///
528    /// 1. Cast `A` (f64) to f32 on device. Factor in fp32 (POTRF).
529    /// 2. Cast `b` (f64) to f32. Solve `A x = b` in fp32 (POTRS). Lift `x`
530    ///    to f64.
531    /// 3. Loop up to `max_steps`:
532    ///    a. `r = b − A·x` accumulated in fp64 (cuBLAS Dgemv).
533    ///    b. `‖r‖ / ‖b‖ ≤ tol` → converged, break.
534    ///    c. Residual did not drop below previous step → bail, return `Err`.
535    ///    d. Cast `r` to f32. Solve `A e = r` in fp32. `x += e` (f64).
536    /// 4. Return `(x, ‖r‖/‖b‖, refinement_steps)`.
537    ///
538    /// Returns `Err` when the fp32 POTRF fails (not SPD at f32) or when the
539    /// residual does not decrease monotonically (κ(A)·u_f32 ≥ 1 regime).
540    /// Callers should fall back to fp64 POTRF on `Err`.
541    pub(super) fn iterative_refinement_solve_impl(
542        hessian: ArrayView2<'_, f64>,
543        rhs: &[f64],
544    ) -> Result<super::RefinementOutcome, String> {
545        use crate::policy::GpuDispatchPolicy;
546        let (p, p2) = hessian.dim();
547        if p == 0 || p != p2 || rhs.len() != p {
548            return Err("iterative_refinement_solve: dimension mismatch".to_string());
549        }
550        let max_steps = GpuDispatchPolicy::REFINEMENT_MAX_STEPS;
551        let tol = GpuDispatchPolicy::REFINEMENT_TOL;
552
553        let (_, stream) = context_and_stream()?;
554        let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
555        let blas = CudaBlas::new(stream.clone()).map_err(|e| format!("cublas init: {e}"))?;
556
557        // Upload fp64 hessian for residual GEMV.
558        let h_col_f64 = to_col_major(&hessian);
559        let a_dev_f64 = pinned_htod(&stream, &h_col_f64)?;
560
561        // Cast A to f32 and upload.
562        let h_col_f32: Vec<f32> = h_col_f64.iter().map(|&v| v as f32).collect();
563        let mut a_dev_f32 =
564            pinned_htod(&stream, &h_col_f32).map_err(|e| format!("upload f32 A: {e}"))?;
565
566        // fp32 POTRF — returns Err if A is not SPD at f32 precision.
567        spotrf_in_place(&solver, &stream, p, &mut a_dev_f32)?;
568
569        // Cast b to f32 and upload; solve in fp32.
570        let b_f32: Vec<f32> = rhs.iter().map(|&v| v as f32).collect();
571        let mut x_dev_f32 =
572            pinned_htod(&stream, &b_f32).map_err(|e| format!("upload f32 rhs: {e}"))?;
573        spotrs_in_place(&solver, &stream, p, 1, &a_dev_f32, &mut x_dev_f32)?;
574
575        // Lift x to f64.
576        let x_f32 = stream
577            .clone_dtoh(&x_dev_f32)
578            .map_err(|e| format!("download f32 x: {e}"))?;
579        let mut x: Vec<f64> = x_f32.iter().map(|&v| v as f64).collect();
580
581        // Compute ‖b‖ for relative residual.
582        let norm_b = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
583        let norm_b_safe = if norm_b > 0.0 { norm_b } else { 1.0 };
584
585        let mut x_dev_f64 = pinned_htod(&stream, &x).map_err(|e| format!("upload f64 x: {e}"))?;
586        let (r0, norm_r0) = residual_norm_and_vec(&blas, &stream, p, &a_dev_f64, &x_dev_f64, rhs)?;
587        let mut rel_residual = norm_r0 / norm_b_safe;
588
589        // Early exit: already converged after initial solve.
590        if rel_residual <= tol {
591            return Ok(super::RefinementOutcome {
592                solution: ndarray::Array1::from_vec(x),
593                relative_residual: rel_residual,
594                used_fp32_factor: true,
595                refinement_steps: 0,
596            });
597        }
598
599        let mut r = r0;
600        let mut prev_norm_r = norm_r0;
601        let mut steps_taken = 0_usize;
602
603        for _ in 0..max_steps {
604            // Cast residual to f32, solve A e = r in fp32.
605            let r_f32: Vec<f32> = r.iter().map(|&v| v as f32).collect();
606            let mut e_dev_f32 =
607                pinned_htod(&stream, &r_f32).map_err(|e| format!("upload f32 residual: {e}"))?;
608            spotrs_in_place(&solver, &stream, p, 1, &a_dev_f32, &mut e_dev_f32)?;
609
610            // x += e in f64.
611            let e_f32 = stream
612                .clone_dtoh(&e_dev_f32)
613                .map_err(|e| format!("download f32 e: {e}"))?;
614            for (xi, ei) in x.iter_mut().zip(e_f32.iter()) {
615                *xi += *ei as f64;
616            }
617            steps_taken += 1;
618
619            // Reupload x_dev_f64 and compute new residual.
620            x_dev_f64 = pinned_htod(&stream, &x).map_err(|e| format!("upload refined x: {e}"))?;
621            let (r_new, norm_r_new) =
622                residual_norm_and_vec(&blas, &stream, p, &a_dev_f64, &x_dev_f64, rhs)?;
623            rel_residual = norm_r_new / norm_b_safe;
624
625            // Check monotone decrease. Non-monotone → κ(A)·u ≥ 1.
626            if norm_r_new >= prev_norm_r {
627                return Err(format!(
628                    "iterative refinement: residual not decreasing ({norm_r_new:.3e} ≥ {prev_norm_r:.3e}); \
629                     κ(A)·u_f32 ≥ 1, cannot refine"
630                ));
631            }
632            prev_norm_r = norm_r_new;
633            r = r_new;
634
635            if rel_residual <= tol {
636                break;
637            }
638        }
639
640        Ok(super::RefinementOutcome {
641            solution: ndarray::Array1::from_vec(x),
642            relative_residual: rel_residual,
643            used_fp32_factor: true,
644            refinement_steps: steps_taken,
645        })
646    }
647
648    /// Bind a specific device ordinal's cached context on the calling thread and
649    /// open a fresh stream on it. This is the per-ordinal entry point used by
650    /// multi-GPU fan-out (`crate::pool::scatter_batched` workers) so a
651    /// Cholesky / TRSM can target the device the worker thread owns. The
652    /// primary-device convenience wrapper [`context_and_stream`] calls this with
653    /// the probe-selected ordinal.
654    pub(crate) fn context_and_stream_for(
655        ordinal: usize,
656    ) -> Result<
657        (
658            std::sync::Arc<CudaContext>,
659            std::sync::Arc<cudarc::driver::CudaStream>,
660        ),
661        String,
662    > {
663        let ctx = super::super::device_runtime::cuda_context_for(ordinal)
664            .ok_or_else(|| format!("cuda context for ordinal {ordinal} unavailable"))?;
665        ctx.bind_to_thread()
666            .map_err(|e| format!("cuda context bind_to_thread: {e}"))?;
667        let stream = ctx.new_stream().map_err(|e| format!("cuda stream: {e}"))?;
668        Ok((ctx, stream))
669    }
670
671    pub fn context_and_stream() -> Result<
672        (
673            std::sync::Arc<CudaContext>,
674            std::sync::Arc<cudarc::driver::CudaStream>,
675        ),
676        String,
677    > {
678        // Route through the runtime's cached primary context for the selected
679        // device so every CUDA client in the process (calibration, session,
680        // cuSolver) shares one CUcontext per ordinal. Falling back to
681        // `CudaContext::new(0)` here would fragment driver state across
682        // distinct contexts, defeat memory-pool sharing, and pin work to
683        // ordinal 0 even when the runtime probe chose a different device.
684        let runtime = super::super::device_runtime::GpuRuntime::global()
685            .ok_or_else(|| "cuda runtime unavailable".to_string())?;
686        context_and_stream_for(runtime.selected_device().ordinal)
687    }
688
689    pub fn pinned_htod<T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits + Copy>(
690        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
691        src: &[T],
692    ) -> Result<CudaSlice<T>, String> {
693        // Originally this routine round-tripped the upload through a
694        // `CU_MEMHOSTALLOC_WRITECOMBINED` pinned staging buffer
695        // (`ctx.alloc_pinned`) to enable async DMA. In cudarc 0.19 the
696        // `PinnedHostSlice` returned from `alloc_pinned` carries an event that
697        // its `Drop` impl unconditionally `event.synchronize()`s before freeing
698        // the host mapping — see cudarc-0.19.7 `core.rs::PinnedHostSlice::drop`.
699        // Because the staging buffer goes out of scope at the end of this
700        // function, the host thread blocks here until the H2D copy completes,
701        // immediately defeating the "async" of pinned DMA. The net cost is two
702        // extra driver calls per upload (`cuMemHostAlloc_WC` + `cuMemFreeHost`)
703        // plus a forced stream synchronization, and the workspace ends up
704        // strictly slower than a plain pageable H2D — the driver already
705        // stages pageable copies internally via its own pinned pool, and that
706        // path does not block the issuing host thread for unrelated stream
707        // work. Issue a direct async H2D from the pageable buffer instead.
708        stream.clone_htod(src).map_err(|e| format!("cuda H2D: {e}"))
709    }
710
711    pub fn potrf_in_place(
712        solver: &DnHandle,
713        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
714        p: usize,
715        h: &mut CudaSlice<f64>,
716    ) -> Result<(), String> {
717        potrf_in_place_generic::<f64>(solver, stream, p, h)
718    }
719
720    pub fn potrs_in_place(
721        solver: &DnHandle,
722        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
723        p: usize,
724        nrhs: usize,
725        h: &CudaSlice<f64>,
726        rhs: &mut CudaSlice<f64>,
727    ) -> Result<(), String> {
728        potrs_in_place_generic::<f64>(solver, stream, p, nrhs, h, rhs)
729    }
730
731    /// Query the cuSOLVER POTRF workspace size for a p×p matrix.
732    ///
733    /// Called once at workspace construction to size the persistent workspace
734    /// buffer. Returns the number of f64 elements required.
735    pub fn potrf_query_lwork(
736        solver: &DnHandle,
737        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
738        p: usize,
739    ) -> Result<usize, String> {
740        potrf_bufsize_generic::<f64>(solver, stream, p)
741    }
742
743    /// POTRF factorization using pre-allocated workspace and info buffers.
744    ///
745    /// Does not allocate, does not download `info`. The caller is responsible
746    /// for calling [`check_deferred_potrf_info`] at end-of-fit to confirm no
747    /// factorization failed.
748    ///
749    /// `workspace` must have been allocated with at least `lwork` elements
750    /// (as reported by [`potrf_query_lwork`] at workspace construction).
751    /// `info_dev` is a 1-element device i32 buffer; after a failed
752    /// factorization it holds a positive integer but stays device-resident.
753    pub fn potrf_in_place_reuse(
754        solver: &DnHandle,
755        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
756        p: usize,
757        lwork: i32,
758        h: &mut CudaSlice<f64>,
759        workspace: &mut CudaSlice<f64>,
760        info_dev: &mut CudaSlice<i32>,
761    ) -> Result<(), String> {
762        let p_i = to_i32(p)?;
763        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
764        {
765            let (h_ptr, _h_record) = h.device_ptr_mut(stream);
766            let (work_ptr, _work_record) = workspace.device_ptr_mut(stream);
767            let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
768            // SAFETY: cuSOLVER potrf; h is p*p col-major, workspace was sized
769            // by potrf_query_lwork, info_dev is a pre-allocated 1-element i32
770            // device buffer. All buffers are live on the same stream.
771            let status = unsafe {
772                cusolver_sys::cusolverDnDpotrf(
773                    solver.cu(),
774                    uplo,
775                    p_i,
776                    h_ptr as *mut f64,
777                    p_i,
778                    work_ptr as *mut f64,
779                    lwork,
780                    info_ptr as *mut i32,
781                )
782            };
783            check_cusolver(status, "cusolverDnDpotrf")?;
784        }
785        Ok(())
786    }
787
788    /// POTRS triangular solve using a pre-allocated info buffer.
789    ///
790    /// Does not allocate, does not download `info`. The caller is responsible
791    /// for calling [`check_deferred_potrs_info`] at end-of-fit.
792    pub fn potrs_in_place_reuse(
793        solver: &DnHandle,
794        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
795        p: usize,
796        nrhs: usize,
797        h: &CudaSlice<f64>,
798        rhs: &mut CudaSlice<f64>,
799        info_dev: &mut CudaSlice<i32>,
800    ) -> Result<(), String> {
801        let p_i = to_i32(p)?;
802        let nrhs_i = to_i32(nrhs)?;
803        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
804        {
805            let (h_ptr, _h_record) = h.device_ptr(stream);
806            let (rhs_ptr, _rhs_record) = rhs.device_ptr_mut(stream);
807            let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
808            // SAFETY: cuSOLVER potrs; h is a p*p Cholesky factor, rhs is p*nrhs,
809            // info_dev is a pre-allocated 1-element i32 device buffer.
810            let status = unsafe {
811                cusolver_sys::cusolverDnDpotrs(
812                    solver.cu(),
813                    uplo,
814                    p_i,
815                    nrhs_i,
816                    h_ptr as *const f64,
817                    p_i,
818                    rhs_ptr as *mut f64,
819                    p_i,
820                    info_ptr as *mut i32,
821                )
822            };
823            check_cusolver(status, "cusolverDnDpotrs")?;
824        }
825        Ok(())
826    }
827
828    /// Download the POTRF deferred info scalar and return an error if non-zero.
829    ///
830    /// Called once at end-of-fit (or whenever the convergence loop exits) to
831    /// surface any factorization failure that was deferred device-side by
832    /// [`potrf_in_place_reuse`].
833    pub fn check_deferred_potrf_info(
834        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
835        info_dev: &CudaSlice<i32>,
836    ) -> Result<(), String> {
837        let info_host = stream
838            .clone_dtoh(info_dev)
839            .map_err(|e| format!("download deferred potrf info: {e}"))?;
840        if info_host[0] == 0 {
841            Ok(())
842        } else {
843            Err(format!(
844                "cusolverDnDpotrf returned info={} (detected at end-of-fit)",
845                info_host[0]
846            ))
847        }
848    }
849
850    /// Download the POTRS deferred info scalar and return an error if non-zero.
851    ///
852    /// Mirrors [`check_deferred_potrf_info`] for the triangular-solve step.
853    pub fn check_deferred_potrs_info(
854        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
855        info_dev: &CudaSlice<i32>,
856    ) -> Result<(), String> {
857        let info_host = stream
858            .clone_dtoh(info_dev)
859            .map_err(|e| format!("download deferred potrs info: {e}"))?;
860        if info_host[0] == 0 {
861            Ok(())
862        } else {
863            Err(format!(
864                "cusolverDnDpotrs returned info={} (detected at end-of-fit)",
865                info_host[0]
866            ))
867        }
868    }
869
870    pub fn cholesky_logdet_from_col_major(factor: &[f64], p: usize) -> f64 {
871        let factor = MatRef::from_column_major_slice(factor, p, p);
872        cholesky_factor_logdet(factor)
873    }
874
875    fn check_cusolver(
876        status: cusolver_sys::cusolverStatus_t,
877        label: &'static str,
878    ) -> Result<(), String> {
879        if status == cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
880            Ok(())
881        } else {
882            Err(format!("{label} failed with {status:?}"))
883        }
884    }
885
/// Narrows a `usize` dimension to the `i32` that the CUDA library calls take.
///
/// # Errors
///
/// Returns `"CUDA dimension {value} exceeds i32"` when `value` does not fit.
fn to_i32(value: usize) -> Result<i32, String> {
    match i32::try_from(value) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(format!("CUDA dimension {value} exceeds i32")),
    }
}
889}
890
891// These solver entry points are consumed by sibling crates (`gam-solve`'s
892// pirls/reml GPU paths, `gam-models`, ...) via `gam_gpu::solver::*`, so they
893// are part of gam-gpu's public surface. `potrf_in_place_generic` is the
894// only one with no cross-crate consumer; it stays crate-private and is
895// reached internally through `crate::solver::potrf_in_place_generic`.
896#[cfg(target_os = "linux")]
897pub(crate) use cuda::potrf_in_place_generic;
898#[cfg(target_os = "linux")]
899pub use cuda::{
900    check_deferred_potrf_info, check_deferred_potrs_info, cholesky_logdet_from_col_major,
901    context_and_stream, pinned_htod, potrf_in_place, potrf_in_place_reuse, potrf_query_lwork,
902    potrs_in_place, potrs_in_place_reuse,
903};
904
905/// Solve `A x = b` with fp32 Cholesky factorization + fp64-residual iterative
906/// refinement, automatically falling back to fp64 when the policy rejects the
907/// attempt or when the fp32 path fails / diverges.
908///
909/// The `p` threshold and maximum step count come from [`GpuDispatchPolicy`]
910/// constants — there is no user-facing knob. The decision path is:
911///
912/// 1. `policy.iterative_refinement_should_attempt(p)` → `false` or
913///    multi-column RHS: skip to the fp64 Cholesky path.
914/// 2. Attempt fp32 POTRF + up to `REFINEMENT_MAX_STEPS` residual-correction
915///    steps. Falls back to fp64 on:
916///    - fp32 POTRF info ≠ 0 (A is not SPD at f32 precision),
917///    - non-monotone residual (κ(A)·u_fp32 ≥ 1 regime).
918/// 3. On fp32 success the logdet is computed from the fp64 Cholesky factor —
919///    BUT only when `need_logdet` is true. The fp64 POTRF is an O(p³)
920///    factorization that fully negates the mixed-precision speedup (the whole
921///    point is to do the expensive factor in fp32), so a caller that only needs
922///    the *solution* (e.g. the PIRLS Newton direction solve, which discards the
923///    logdet) passes `need_logdet = false` and the redundant fp64 POTRF is
924///    skipped entirely — the returned logdet is `NaN` in that case. The solution
925///    is always full-fp64-accurate via the residual refinement regardless.
926///
927/// Returns `(solution, logdet, Some(RefinementOutcome))` when the fp32 path
928/// succeeded, or `(solution, logdet, None)` on the fp64 fallback. When
929/// `need_logdet` is false and the fp32 path succeeds, the logdet field is `NaN`.
930pub fn iterative_refinement_cholesky_solve(
931    hessian: ArrayView2<'_, f64>,
932    rhs: ArrayView2<'_, f64>,
933    need_logdet: bool,
934) -> Result<(Array2<f64>, f64, Option<RefinementOutcome>), String> {
935    #[cfg(not(target_os = "linux"))]
936    {
937        let (rows, cols) = hessian.dim();
938        return Err(format!(
939            "CUDA support not compiled; hessian={rows}x{cols}, rhs={}x{}, need_logdet={need_logdet}",
940            rhs.nrows(),
941            rhs.ncols()
942        ));
943    }
944
945    #[cfg(target_os = "linux")]
946    {
947        let runtime = super::device_runtime::GpuRuntime::global().ok_or_else(|| {
948            let (rows, cols) = hessian.dim();
949            format!(
950                "CUDA runtime unavailable; hessian={rows}x{cols}, rhs={}x{}",
951                rhs.nrows(),
952                rhs.ncols()
953            )
954        })?;
955        let p = hessian.nrows();
956
957        // Attempt fp32 + refinement only for single-column RHS with p large
958        // enough that the fp64 GEMV residual cost is amortised.
959        if rhs.ncols() == 1 && runtime.policy.iterative_refinement_should_attempt(p) {
960            let rhs_col = rhs.column(0);
961            let rhs_slice: Vec<f64> = rhs_col.iter().copied().collect();
962            if let Ok(outcome) = cuda::iterative_refinement_solve_impl(hessian, &rhs_slice) {
963                // fp32 + refinement succeeded; the refined solution is full
964                // fp64 accuracy. The logdet, however, needs the fp64 Cholesky
965                // factor (the fp32 diagonal is only fp32-accurate, and the
966                // logdet feeds the REML criterion / EDF). Run the fp64 POTRF
967                // ONLY when the caller actually consumes the logdet: otherwise
968                // that O(p³) factorization is pure overhead that cancels the
969                // mixed-precision win (the expensive factor would then run in
970                // BOTH precisions). A solution-only caller (PIRLS Newton
971                // direction, which discards the logdet) gets the genuine
972                // fp32-factor speedup; logdet is reported as NaN.
973                let mut sol = Array2::<f64>::zeros((p, 1));
974                sol.column_mut(0).assign(&outcome.solution);
975                if !need_logdet {
976                    return Ok((sol, f64::NAN, Some(outcome)));
977                }
978                if let Ok(logdet) = cuda::cholesky_logdet(hessian) {
979                    return Ok((sol, logdet, Some(outcome)));
980                }
981                // fp64 logdet failed (theoretically impossible for SPD A);
982                // fall through to plain fp64 path.
983            }
984            // fp32 path failed (not SPD at f32, or residual non-monotone) →
985            // fall through to fp64.
986        }
987
988        let (sol, logdet) = cuda::cholesky_solve(hessian, rhs)?;
989        Ok((sol, logdet, None))
990    }
991}
992
993pub fn cholesky_solve_gpu(
994    hessian: ArrayView2<'_, f64>,
995    rhs: ArrayView2<'_, f64>,
996) -> Result<(Array2<f64>, f64), String> {
997    // Route through iterative refinement. The function falls back to fp64
998    // internally, so callers always get a valid result; the refinement
999    // outcome metadata is intentionally not surfaced by this thin wrapper.
1000    // This wrapper returns the logdet, so it must request it (`need_logdet`).
1001    let result = iterative_refinement_cholesky_solve(hessian, rhs, /*need_logdet=*/ true)?;
1002    Ok((result.0, result.1))
1003}
1004
1005/// Solution-only mixed-precision solve: like [`cholesky_solve_gpu`] but skips
1006/// the redundant fp64 POTRF when the fp32 + refinement path succeeds, since the
1007/// caller does not consume the log-determinant. This is the path that delivers
1008/// the full mixed-precision speedup (expensive O(p³) factor stays fp32) for the
1009/// PIRLS Newton direction solve, where the logdet is discarded. The solution is
1010/// full fp64 accuracy via iterative refinement.
1011pub fn cholesky_solve_only_gpu(
1012    hessian: ArrayView2<'_, f64>,
1013    rhs: ArrayView2<'_, f64>,
1014) -> Result<Array2<f64>, String> {
1015    let result = iterative_refinement_cholesky_solve(hessian, rhs, /*need_logdet=*/ false)?;
1016    Ok(result.0)
1017}
1018
1019pub fn cholesky_lower_gpu(hessian: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
1020    #[cfg(not(target_os = "linux"))]
1021    {
1022        let (rows, cols) = hessian.dim();
1023        return Err(format!(
1024            "CUDA support not compiled for Cholesky factorization; hessian={rows}x{cols}"
1025        ));
1026    }
1027
1028    #[cfg(target_os = "linux")]
1029    {
1030        if super::device_runtime::GpuRuntime::global().is_none() {
1031            let (rows, cols) = hessian.dim();
1032            return Err(format!(
1033                "CUDA runtime unavailable for Cholesky factorization; hessian={rows}x{cols}"
1034            ));
1035        }
1036        cuda::cholesky_lower(hessian)
1037    }
1038}
1039
1040#[cfg(target_os = "linux")]
1041pub(crate) fn cholesky_lower_on_ordinal_gpu(
1042    ordinal: usize,
1043    hessian: ArrayView2<'_, f64>,
1044) -> Result<Array2<f64>, String> {
1045    cuda::cholesky_lower_on_ordinal(ordinal, hessian)
1046}