Skip to main content

gam_gpu/
solver.rs

1//! cuSOLVER-backed dense solver kernels for the GPU HAL.
2//!
3//! This module owns CUDA solver functionality that is shared by GPU linear
4//! algebra dispatch and higher-level solver code. CPU solves do not live behind
5//! these entry points: unavailable CUDA support is reported as an error.
6
7use ndarray::{Array2, ArrayView2};
8
9pub fn solver_backend_status() -> super::CudaBackendStatus {
10    super::cuda_backend_status()
11}
12
13/// Outcome reported by [`iterative_refinement_cholesky_solve`].
14#[derive(Clone, Debug)]
15pub struct RefinementOutcome {
16    /// Solution vector `x` satisfying `A x ≈ b`.
17    pub solution: ndarray::Array1<f64>,
18    /// `‖r‖ / ‖b‖` where `r = b − A x` after the last refinement step
19    /// (or after the initial fp32 solve when no steps were taken).
20    pub relative_residual: f64,
21    /// Precision path used for the factorization.
22    pub used_fp32_factor: bool,
23    /// Number of refinement steps taken (0 means only the initial solve ran).
24    pub refinement_steps: usize,
25}
26
27#[cfg(target_os = "linux")]
28mod cuda {
29    use crate::driver::{from_col_major, to_col_major};
30    use gam_linalg::faer_ndarray::cholesky_factor_logdet;
31    use cudarc::cublas::sys as cublas_sys;
32    use cudarc::cublas::{CudaBlas, Gemv, GemvConfig};
33    use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
34    use cudarc::driver::{CudaContext, CudaSlice, DevicePtr, DevicePtrMut};
35    use faer::MatRef;
36    use ndarray::{Array2, ArrayView2};
37
38    pub(super) fn cholesky_solve(
39        hessian: ArrayView2<'_, f64>,
40        rhs: ArrayView2<'_, f64>,
41    ) -> Result<(Array2<f64>, f64), String> {
42        let (_, stream) = context_and_stream()?;
43        let (p, p2) = hessian.dim();
44        if p == 0 || p != p2 || rhs.nrows() != p {
45            return Err("Cholesky solve dimension mismatch".to_string());
46        }
47        let nrhs = rhs.ncols();
48        let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
49        let h_col = to_col_major(&hessian);
50        let rhs_col = to_col_major(&rhs);
51        let mut h_dev = pinned_htod(&stream, &h_col)?;
52        let mut rhs_dev = pinned_htod(&stream, &rhs_col)?;
53        potrf_in_place(&solver, &stream, p, &mut h_dev)?;
54        potrs_in_place(&solver, &stream, p, nrhs, &h_dev, &mut rhs_dev)?;
55        let factor_col = stream
56            .clone_dtoh(&h_dev)
57            .map_err(|e| format!("download Cholesky factor: {e}"))?;
58        let out_col = stream
59            .clone_dtoh(&rhs_dev)
60            .map_err(|e| format!("download solution: {e}"))?;
61        let solved =
62            from_col_major(&out_col, p, nrhs).ok_or("solution layout conversion failed")?;
63        Ok((solved, cholesky_logdet_from_col_major(&factor_col, p)))
64    }
65
66    /// fp64 log-determinant of an SPD matrix via POTRF only.
67    ///
68    /// This is [`cholesky_solve`] stripped of the triangular solve (POTRS) and
69    /// the solution download/layout conversion: the log-determinant depends
70    /// solely on the Cholesky factor's diagonal, so when a caller already holds
71    /// the solution (e.g. from fp32 + iterative refinement) and needs *only* an
72    /// accurate fp64 logdet, doing a full solve here would burn an O(p²·nrhs)
73    /// POTRS plus a host round-trip on a solution that is immediately discarded.
74    pub(super) fn cholesky_logdet(hessian: ArrayView2<'_, f64>) -> Result<f64, String> {
75        let (_, stream) = context_and_stream()?;
76        let (p, p2) = hessian.dim();
77        if p == 0 || p != p2 {
78            return Err("Cholesky logdet dimension mismatch".to_string());
79        }
80        let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
81        let h_col = to_col_major(&hessian);
82        let mut h_dev = pinned_htod(&stream, &h_col)?;
83        potrf_in_place(&solver, &stream, p, &mut h_dev)?;
84        let factor_col = stream
85            .clone_dtoh(&h_dev)
86            .map_err(|e| format!("download Cholesky factor: {e}"))?;
87        Ok(cholesky_logdet_from_col_major(&factor_col, p))
88    }
89
90    pub(super) fn cholesky_lower(hessian: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
91        let (_, stream) = context_and_stream()?;
92        cholesky_lower_on_stream(hessian, &stream)
93    }
94
95    pub(super) fn cholesky_lower_on_ordinal(
96        ordinal: usize,
97        hessian: ArrayView2<'_, f64>,
98    ) -> Result<Array2<f64>, String> {
99        let (_, stream) = context_and_stream_for(ordinal)?;
100        cholesky_lower_on_stream(hessian, &stream)
101    }
102
103    fn cholesky_lower_on_stream(
104        hessian: ArrayView2<'_, f64>,
105        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
106    ) -> Result<Array2<f64>, String> {
107        let (p, p2) = hessian.dim();
108        if p == 0 || p != p2 {
109            return Err("Cholesky factorization dimension mismatch".to_string());
110        }
111        let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
112        let h_col = to_col_major(&hessian);
113        let mut h_dev = pinned_htod(&stream, &h_col)?;
114        potrf_in_place(&solver, &stream, p, &mut h_dev)?;
115        let factor_col = stream
116            .clone_dtoh(&h_dev)
117            .map_err(|e| format!("download Cholesky factor: {e}"))?;
118        let mut lower =
119            from_col_major(&factor_col, p, p).ok_or("factor layout conversion failed")?;
120        for row in 0..p {
121            for col in (row + 1)..p {
122                lower[[row, col]] = 0.0;
123            }
124        }
125        Ok(lower)
126    }
127
128    // -----------------------------------------------------------------------
129    // Precision-generic Cholesky scaffold
130    //
131    // POTRF / POTRS host scaffolds are identical across single and double
132    // precision apart from the cuSOLVER symbol called and the device pointer
133    // type. `CholScalar` selects those per-precision pieces so the host-side
134    // allocation / info-handling / error-formatting logic lives once. The
    // The f64 (`potrf_in_place` / `potrs_in_place`) and f32 (`spotrf_in_place`
    // / `spotrs_in_place`) entry points below are thin wrappers over the
    // generic helpers, so their existing signatures are preserved exactly.
138    // -----------------------------------------------------------------------
139
140    /// cuSOLVER scalar abstraction: selects the precision-specific POTRF/POTRS
141    /// symbols and the precision tag used in deferred-info error messages.
142    ///
143    /// The FFI into cuSOLVER lives inside the trait methods' bodies (in `unsafe`
144    /// blocks), so the trait and its methods are safe to call: each impl wires
145    /// its method bodies to the cuSOLVER entry points whose pointer arguments
146    /// match `Self` (e.g. `cusolverDnDpotrf` for `f64`). Implementors must keep
147    /// that pairing consistent — the device pointer passed in is typed `*mut
148    /// Self` / `*const Self`, so a mismatched symbol would hand cuSOLVER a
149    /// wrongly-typed buffer.
150    pub(crate) trait CholScalar:
151        cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits + Copy
152    {
153        /// cuSOLVER `*potrf_bufferSize`: `(handle, uplo, n, A, lda, *lwork)`.
154        ///
155        /// `a` is a live `n*n` column-major device buffer of type `Self`,
156        /// `lwork` is a host out-param. The unsafe FFI call is contained in the
157        /// method body.
158        fn potrf_buffer_size(
159            handle: cusolver_sys::cusolverDnHandle_t,
160            uplo: cusolver_sys::cublasFillMode_t,
161            n: i32,
162            a: *mut Self,
163            lda: i32,
164            lwork: *mut i32,
165        ) -> cusolver_sys::cusolverStatus_t;
166        /// cuSOLVER `*potrf`: `(handle, uplo, n, A, lda, work, lwork, info)`.
167        ///
168        /// Pointer args must reference live device buffers of the documented
169        /// shape; the unsafe FFI call is contained in the method body.
170        fn potrf(
171            handle: cusolver_sys::cusolverDnHandle_t,
172            uplo: cusolver_sys::cublasFillMode_t,
173            n: i32,
174            a: *mut Self,
175            lda: i32,
176            work: *mut Self,
177            lwork: i32,
178            info: *mut i32,
179        ) -> cusolver_sys::cusolverStatus_t;
180        /// cuSOLVER `*potrs`: `(handle, uplo, n, nrhs, A, lda, B, ldb, info)`.
181        ///
182        /// Pointer args must reference live device buffers of the documented
183        /// shape; the unsafe FFI call is contained in the method body.
184        fn potrs(
185            handle: cusolver_sys::cusolverDnHandle_t,
186            uplo: cusolver_sys::cublasFillMode_t,
187            n: i32,
188            nrhs: i32,
189            a: *const Self,
190            lda: i32,
191            b: *mut Self,
192            ldb: i32,
193            info: *mut i32,
194        ) -> cusolver_sys::cusolverStatus_t;
195        /// Symbol name fragment for error messages (e.g. `"Dpotrf"`).
196        const POTRF_NAME: &'static str;
197        const POTRS_NAME: &'static str;
198        /// Trailing clause appended to a POTRF "not SPD" error (e.g.
199        /// `" (matrix not SPD at f32)"`); empty for f64.
200        const POTRF_FAIL_SUFFIX: &'static str;
201    }
202
203    impl CholScalar for f64 {
204        fn potrf_buffer_size(
205            handle: cusolver_sys::cusolverDnHandle_t,
206            uplo: cusolver_sys::cublasFillMode_t,
207            n: i32,
208            a: *mut f64,
209            lda: i32,
210            lwork: *mut i32,
211        ) -> cusolver_sys::cusolverStatus_t {
212            // SAFETY: caller guarantees `a` is a live n*n column-major f64 device
213            // buffer and `lwork` is a valid host out-param; symbol matches f64.
214            unsafe { cusolver_sys::cusolverDnDpotrf_bufferSize(handle, uplo, n, a, lda, lwork) }
215        }
216        fn potrf(
217            handle: cusolver_sys::cusolverDnHandle_t,
218            uplo: cusolver_sys::cublasFillMode_t,
219            n: i32,
220            a: *mut f64,
221            lda: i32,
222            work: *mut f64,
223            lwork: i32,
224            info: *mut i32,
225        ) -> cusolver_sys::cusolverStatus_t {
226            // SAFETY: caller guarantees `a` is a live n*n column-major f64 buffer,
227            // `work` was sized by potrf_buffer_size, `info` is a 1-element i32
228            // device buffer; symbol matches f64.
229            unsafe { cusolver_sys::cusolverDnDpotrf(handle, uplo, n, a, lda, work, lwork, info) }
230        }
231        fn potrs(
232            handle: cusolver_sys::cusolverDnHandle_t,
233            uplo: cusolver_sys::cublasFillMode_t,
234            n: i32,
235            nrhs: i32,
236            a: *const f64,
237            lda: i32,
238            b: *mut f64,
239            ldb: i32,
240            info: *mut i32,
241        ) -> cusolver_sys::cusolverStatus_t {
242            // SAFETY: caller guarantees `a` is a live n*n f64 Cholesky factor,
243            // `b` is n*nrhs column-major f64, `info` is a 1-element i32 device
244            // buffer; symbol matches f64.
245            unsafe { cusolver_sys::cusolverDnDpotrs(handle, uplo, n, nrhs, a, lda, b, ldb, info) }
246        }
247        const POTRF_NAME: &'static str = "Dpotrf";
248        const POTRS_NAME: &'static str = "Dpotrs";
249        const POTRF_FAIL_SUFFIX: &'static str = "";
250    }
251
252    impl CholScalar for f32 {
253        fn potrf_buffer_size(
254            handle: cusolver_sys::cusolverDnHandle_t,
255            uplo: cusolver_sys::cublasFillMode_t,
256            n: i32,
257            a: *mut f32,
258            lda: i32,
259            lwork: *mut i32,
260        ) -> cusolver_sys::cusolverStatus_t {
261            // SAFETY: caller guarantees `a` is a live n*n column-major f32 device
262            // buffer and `lwork` is a valid host out-param; symbol matches f32.
263            unsafe { cusolver_sys::cusolverDnSpotrf_bufferSize(handle, uplo, n, a, lda, lwork) }
264        }
265        fn potrf(
266            handle: cusolver_sys::cusolverDnHandle_t,
267            uplo: cusolver_sys::cublasFillMode_t,
268            n: i32,
269            a: *mut f32,
270            lda: i32,
271            work: *mut f32,
272            lwork: i32,
273            info: *mut i32,
274        ) -> cusolver_sys::cusolverStatus_t {
275            // SAFETY: caller guarantees `a` is a live n*n column-major f32 buffer,
276            // `work` was sized by potrf_buffer_size, `info` is a 1-element i32
277            // device buffer; symbol matches f32.
278            unsafe { cusolver_sys::cusolverDnSpotrf(handle, uplo, n, a, lda, work, lwork, info) }
279        }
280        fn potrs(
281            handle: cusolver_sys::cusolverDnHandle_t,
282            uplo: cusolver_sys::cublasFillMode_t,
283            n: i32,
284            nrhs: i32,
285            a: *const f32,
286            lda: i32,
287            b: *mut f32,
288            ldb: i32,
289            info: *mut i32,
290        ) -> cusolver_sys::cusolverStatus_t {
291            // SAFETY: caller guarantees `a` is a live n*n f32 Cholesky factor,
292            // `b` is n*nrhs column-major f32, `info` is a 1-element i32 device
293            // buffer; symbol matches f32.
294            unsafe { cusolver_sys::cusolverDnSpotrs(handle, uplo, n, nrhs, a, lda, b, ldb, info) }
295        }
296        const POTRF_NAME: &'static str = "Spotrf";
297        const POTRS_NAME: &'static str = "Spotrs";
298        const POTRF_FAIL_SUFFIX: &'static str = " (matrix not SPD at f32)";
299    }
300
301    /// Query the cuSOLVER POTRF workspace size (element count) for a p×p
302    /// matrix at precision `T`. Allocates a temporary p×p dummy buffer for the
303    /// query.
304    fn potrf_bufsize_generic<T: CholScalar>(
305        solver: &DnHandle,
306        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
307        p: usize,
308    ) -> Result<usize, String> {
309        let p_i = to_i32(p)?;
310        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
311        let mut lwork = 0_i32;
312        let mut dummy = stream
313            .alloc_zeros::<T>(p.checked_mul(p).ok_or("p² overflow in lwork query")?)
314            .map_err(|e| format!("cuda alloc dummy for lwork query: {e}"))?;
315        {
316            let (ptr, _rec) = dummy.device_ptr_mut(stream);
317            // dummy is a live p*p device buffer of type T, lwork is a host i32;
318            // the unsafe cuSOLVER FFI is contained in T::potrf_buffer_size.
319            let status =
320                T::potrf_buffer_size(solver.cu(), uplo, p_i, ptr as *mut T, p_i, &mut lwork);
321            check_cusolver(status, "cusolverDn*potrf_bufferSize")?;
322        }
323        usize::try_from(lwork).map_err(|_| "negative potrf lwork".to_string())
324    }
325
326    /// Factor a p×p SPD device buffer in-place (lower-triangular Cholesky) at
327    /// precision `T`, querying and allocating its own workspace. Returns `Err`
328    /// if the matrix is singular/indefinite at precision `T`.
329    ///
330    /// This is the single-matrix POTRF core shared across the GPU layer:
331    /// `solver.rs`'s `potrf_in_place`/`spotrf_in_place` and `linalg.rs`'s
332    /// `potrf_lower_in_place` all route through it (the latter mapping the
333    /// `Result` to its `Option` contract at the boundary). The batched POTRF
334    /// (`cusolverDnDpotrfBatched`) in `linalg.rs` is intentionally separate.
335    pub(crate) fn potrf_in_place_generic<T: CholScalar>(
336        solver: &DnHandle,
337        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
338        p: usize,
339        a: &mut CudaSlice<T>,
340    ) -> Result<(), String> {
341        let p_i = to_i32(p)?;
342        let lwork = potrf_bufsize_generic::<T>(solver, stream, p)?;
343        let lwork_i = i32::try_from(lwork).map_err(|_| "negative potrf workspace".to_string())?;
344        let mut workspace = stream
345            .alloc_zeros::<T>(lwork.max(1))
346            .map_err(|e| format!("cuda alloc potrf workspace: {e}"))?;
347        let mut info = stream
348            .alloc_zeros::<i32>(1)
349            .map_err(|e| format!("cuda alloc potrf info: {e}"))?;
350        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
351        {
352            let (a_ptr, _a_rec) = a.device_ptr_mut(stream);
353            let (work_ptr, _work_rec) = workspace.device_ptr_mut(stream);
354            let (info_ptr, _info_rec) = info.device_ptr_mut(stream);
355            // a is p*p col-major T, workspace was sized by T::potrf_buffer_size,
356            // info is a 1-element i32 device buffer; the unsafe cuSOLVER FFI is
357            // contained in T::potrf.
358            let status = T::potrf(
359                solver.cu(),
360                uplo,
361                p_i,
362                a_ptr as *mut T,
363                p_i,
364                work_ptr as *mut T,
365                lwork_i,
366                info_ptr as *mut i32,
367            );
368            check_cusolver(status, "cusolverDn*potrf")?;
369        }
370        let info_host = stream
371            .clone_dtoh(&info)
372            .map_err(|e| format!("download potrf info: {e}"))?;
373        if info_host[0] == 0 {
374            Ok(())
375        } else {
376            Err(format!(
377                "cusolverDn{} returned info={}{}",
378                T::POTRF_NAME,
379                info_host[0],
380                T::POTRF_FAIL_SUFFIX
381            ))
382        }
383    }
384
385    /// Triangular solve using a pre-factored Cholesky lower-triangle at
386    /// precision `T`. Solves `A · x = rhs` in-place into `rhs` (column-major,
387    /// p × nrhs), allocating and downloading its own info scalar.
388    fn potrs_in_place_generic<T: CholScalar>(
389        solver: &DnHandle,
390        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
391        p: usize,
392        nrhs: usize,
393        factor: &CudaSlice<T>,
394        rhs: &mut CudaSlice<T>,
395    ) -> Result<(), String> {
396        let p_i = to_i32(p)?;
397        let nrhs_i = to_i32(nrhs)?;
398        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
399        let mut info = stream
400            .alloc_zeros::<i32>(1)
401            .map_err(|e| format!("cuda alloc potrs info: {e}"))?;
402        {
403            let (f_ptr, _f_rec) = factor.device_ptr(stream);
404            let (r_ptr, _r_rec) = rhs.device_ptr_mut(stream);
405            let (info_ptr, _info_rec) = info.device_ptr_mut(stream);
406            // factor is a p*p lower-triangular T from potrf, rhs is p*nrhs
407            // col-major T, info is a 1-element i32 device buffer; leading dims
408            // match column-major p_i. The unsafe cuSOLVER FFI is contained in
409            // T::potrs.
410            let status = T::potrs(
411                solver.cu(),
412                uplo,
413                p_i,
414                nrhs_i,
415                f_ptr as *const T,
416                p_i,
417                r_ptr as *mut T,
418                p_i,
419                info_ptr as *mut i32,
420            );
421            check_cusolver(status, "cusolverDn*potrs")?;
422        }
423        let info_host = stream
424            .clone_dtoh(&info)
425            .map_err(|e| format!("download potrs info: {e}"))?;
426        if info_host[0] == 0 {
427            Ok(())
428        } else {
429            Err(format!(
430                "cusolverDn{} returned info={}",
431                T::POTRS_NAME,
432                info_host[0]
433            ))
434        }
435    }
436
437    // -----------------------------------------------------------------------
438    // fp32 entry points (thin wrappers over the precision-generic scaffold)
439    // -----------------------------------------------------------------------
440
441    /// Factor a p×p symmetric positive-definite f32 device buffer in-place
442    /// (lower-triangular Cholesky). Returns `Err` if the matrix is
443    /// singular/indefinite.
444    fn spotrf_in_place(
445        solver: &DnHandle,
446        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
447        p: usize,
448        a: &mut CudaSlice<f32>,
449    ) -> Result<(), String> {
450        potrf_in_place_generic::<f32>(solver, stream, p, a)
451    }
452
453    /// Triangular solve using a pre-factored fp32 Cholesky lower-triangle.
454    /// Solves `A · x = rhs` in-place into `rhs` (column-major, p × nrhs).
455    fn spotrs_in_place(
456        solver: &DnHandle,
457        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
458        p: usize,
459        nrhs: usize,
460        factor: &CudaSlice<f32>,
461        rhs: &mut CudaSlice<f32>,
462    ) -> Result<(), String> {
463        potrs_in_place_generic::<f32>(solver, stream, p, nrhs, factor, rhs)
464    }
465
466    // -----------------------------------------------------------------------
467    // fp64 DGEMV residual: r = b − A·x in double precision
468    // -----------------------------------------------------------------------
469
470    /// Compute `r = b − A·x` in fp64 where A is p×p and x, b, r are length p.
471    ///
472    /// Overwrites the output buffer `r_dev` with the residual. Uses
473    /// `cublasDgemv` (CUBLAS_OP_N): `r = 1·A·x + 0·0 = A·x`, then the host
474    /// subtracts from b. Because p is small here (the policy gates on p ≥ 64
475    /// and the Newton system is p×p), downloading the p-vector for the host
476    /// subtract is cheap relative to the GEMV.
477    fn residual_norm_and_vec(
478        blas: &CudaBlas,
479        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
480        p: usize,
481        a_dev: &CudaSlice<f64>,
482        x_dev: &CudaSlice<f64>,
483        b_host: &[f64],
484    ) -> Result<(Vec<f64>, f64), String> {
485        let p_i = to_i32(p)?;
486        // ax_dev = A · x
487        let mut ax_dev = stream
488            .alloc_zeros::<f64>(p)
489            .map_err(|e| format!("alloc ax: {e}"))?;
490        {
491            let cfg = GemvConfig::<f64> {
492                trans: cublas_sys::cublasOperation_t::CUBLAS_OP_N,
493                m: p_i,
494                n: p_i,
495                alpha: 1.0_f64,
496                lda: p_i,
497                incx: 1,
498                beta: 0.0_f64,
499                incy: 1,
500            };
501            // SAFETY: cuBLAS Dgemv; a_dev is p*p col-major f64, x_dev is
502            // length-p f64, ax_dev is length-p output; all on the same stream.
503            unsafe { blas.gemv(cfg, a_dev, x_dev, &mut ax_dev) }
504                .map_err(|e| format!("cublasDgemv for residual: {e}"))?;
505        }
506        let ax_host = stream
507            .clone_dtoh(&ax_dev)
508            .map_err(|e| format!("download A·x: {e}"))?;
509        // r = b − A·x  (host subtract; p is small)
510        let r: Vec<f64> = b_host
511            .iter()
512            .zip(ax_host.iter())
513            .map(|(bi, axi)| bi - axi)
514            .collect();
515        let norm_r = r.iter().map(|v| v * v).sum::<f64>().sqrt();
516        Ok((r, norm_r))
517    }
518
519    // -----------------------------------------------------------------------
520    // Iterative refinement: fp32 factor → fp32 solve → fp64 residual loop
521    // -----------------------------------------------------------------------
522
523    /// Solve `A x = b` using an fp32 Cholesky factorization with up to
524    /// `max_steps` fp64-residual iterative refinement corrections.
525    ///
526    /// # Algorithm
527    ///
528    /// 1. Cast `A` (f64) to f32 on device. Factor in fp32 (POTRF).
529    /// 2. Cast `b` (f64) to f32. Solve `A x = b` in fp32 (POTRS). Lift `x`
530    ///    to f64.
531    /// 3. Loop up to `max_steps`:
532    ///    a. `r = b − A·x` accumulated in fp64 (cuBLAS Dgemv).
533    ///    b. `‖r‖ / ‖b‖ ≤ tol` → converged, break.
534    ///    c. Residual did not drop below previous step → bail, return `Err`.
535    ///    d. Cast `r` to f32. Solve `A e = r` in fp32. `x += e` (f64).
536    /// 4. Return `(x, ‖r‖/‖b‖, refinement_steps)`.
537    ///
538    /// Returns `Err` when the fp32 POTRF fails (not SPD at f32) or when the
539    /// residual does not decrease monotonically (κ(A)·u_f32 ≥ 1 regime).
540    /// Callers should fall back to fp64 POTRF on `Err`.
541    pub(super) fn iterative_refinement_solve_impl(
542        hessian: ArrayView2<'_, f64>,
543        rhs: &[f64],
544    ) -> Result<super::RefinementOutcome, String> {
545        use crate::policy::GpuDispatchPolicy;
546        let (p, p2) = hessian.dim();
547        if p == 0 || p != p2 || rhs.len() != p {
548            return Err("iterative_refinement_solve: dimension mismatch".to_string());
549        }
550        let max_steps = GpuDispatchPolicy::REFINEMENT_MAX_STEPS;
551        let tol = GpuDispatchPolicy::REFINEMENT_TOL;
552
553        let (_, stream) = context_and_stream()?;
554        let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
555        let blas = CudaBlas::new(stream.clone()).map_err(|e| format!("cublas init: {e}"))?;
556
557        // Upload fp64 hessian for residual GEMV.
558        let h_col_f64 = to_col_major(&hessian);
559        let a_dev_f64 = pinned_htod(&stream, &h_col_f64)?;
560
561        // Cast A to f32 and upload.
562        let h_col_f32: Vec<f32> = h_col_f64.iter().map(|&v| v as f32).collect();
563        let mut a_dev_f32 =
564            pinned_htod(&stream, &h_col_f32).map_err(|e| format!("upload f32 A: {e}"))?;
565
566        // fp32 POTRF — returns Err if A is not SPD at f32 precision.
567        spotrf_in_place(&solver, &stream, p, &mut a_dev_f32)?;
568
569        // Cast b to f32 and upload; solve in fp32.
570        let b_f32: Vec<f32> = rhs.iter().map(|&v| v as f32).collect();
571        let mut x_dev_f32 =
572            pinned_htod(&stream, &b_f32).map_err(|e| format!("upload f32 rhs: {e}"))?;
573        spotrs_in_place(&solver, &stream, p, 1, &a_dev_f32, &mut x_dev_f32)?;
574
575        // Lift x to f64.
576        let x_f32 = stream
577            .clone_dtoh(&x_dev_f32)
578            .map_err(|e| format!("download f32 x: {e}"))?;
579        let mut x: Vec<f64> = x_f32.iter().map(|&v| v as f64).collect();
580
581        // Compute ‖b‖ for relative residual.
582        let norm_b = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
583        let norm_b_safe = if norm_b > 0.0 { norm_b } else { 1.0 };
584
585        let mut x_dev_f64 = pinned_htod(&stream, &x).map_err(|e| format!("upload f64 x: {e}"))?;
586        let (r0, norm_r0) = residual_norm_and_vec(&blas, &stream, p, &a_dev_f64, &x_dev_f64, rhs)?;
587        let mut rel_residual = norm_r0 / norm_b_safe;
588
589        // Early exit: already converged after initial solve.
590        if rel_residual <= tol {
591            return Ok(super::RefinementOutcome {
592                solution: ndarray::Array1::from_vec(x),
593                relative_residual: rel_residual,
594                used_fp32_factor: true,
595                refinement_steps: 0,
596            });
597        }
598
599        let mut r = r0;
600        let mut prev_norm_r = norm_r0;
601        let mut steps_taken = 0_usize;
602
603        for _ in 0..max_steps {
604            // Cast residual to f32, solve A e = r in fp32.
605            let r_f32: Vec<f32> = r.iter().map(|&v| v as f32).collect();
606            let mut e_dev_f32 =
607                pinned_htod(&stream, &r_f32).map_err(|e| format!("upload f32 residual: {e}"))?;
608            spotrs_in_place(&solver, &stream, p, 1, &a_dev_f32, &mut e_dev_f32)?;
609
610            // x += e in f64.
611            let e_f32 = stream
612                .clone_dtoh(&e_dev_f32)
613                .map_err(|e| format!("download f32 e: {e}"))?;
614            for (xi, ei) in x.iter_mut().zip(e_f32.iter()) {
615                *xi += *ei as f64;
616            }
617            steps_taken += 1;
618
619            // Reupload x_dev_f64 and compute new residual.
620            x_dev_f64 = pinned_htod(&stream, &x).map_err(|e| format!("upload refined x: {e}"))?;
621            let (r_new, norm_r_new) =
622                residual_norm_and_vec(&blas, &stream, p, &a_dev_f64, &x_dev_f64, rhs)?;
623            rel_residual = norm_r_new / norm_b_safe;
624
625            // Check monotone decrease. Non-monotone → κ(A)·u ≥ 1.
626            if norm_r_new >= prev_norm_r {
627                return Err(format!(
628                    "iterative refinement: residual not decreasing ({norm_r_new:.3e} ≥ {prev_norm_r:.3e}); \
629                     κ(A)·u_f32 ≥ 1, cannot refine"
630                ));
631            }
632            prev_norm_r = norm_r_new;
633            r = r_new;
634
635            if rel_residual <= tol {
636                break;
637            }
638        }
639
640        Ok(super::RefinementOutcome {
641            solution: ndarray::Array1::from_vec(x),
642            relative_residual: rel_residual,
643            used_fp32_factor: true,
644            refinement_steps: steps_taken,
645        })
646    }
647
648    /// Bind a specific device ordinal's cached context on the calling thread and
649    /// open a fresh stream on it. This is the per-ordinal entry point used by
650    /// multi-GPU fan-out (`crate::pool::scatter_batched` workers) so a
651    /// Cholesky / TRSM can target the device the worker thread owns. The
652    /// primary-device convenience wrapper [`context_and_stream`] calls this with
653    /// the probe-selected ordinal.
654    pub(crate) fn context_and_stream_for(
655        ordinal: usize,
656    ) -> Result<
657        (
658            std::sync::Arc<CudaContext>,
659            std::sync::Arc<cudarc::driver::CudaStream>,
660        ),
661        String,
662    > {
663        let ctx = super::super::device_runtime::cuda_context_for(ordinal)
664            .ok_or_else(|| format!("cuda context for ordinal {ordinal} unavailable"))?;
665        ctx.bind_to_thread()
666            .map_err(|e| format!("cuda context bind_to_thread: {e}"))?;
667        let stream = ctx.new_stream().map_err(|e| format!("cuda stream: {e}"))?;
668        Ok((ctx, stream))
669    }
670
671    pub fn context_and_stream() -> Result<
672        (
673            std::sync::Arc<CudaContext>,
674            std::sync::Arc<cudarc::driver::CudaStream>,
675        ),
676        String,
677    > {
678        // Route through the runtime's cached primary context for the selected
679        // device so every CUDA client in the process (calibration, session,
680        // cuSolver) shares one CUcontext per ordinal. Falling back to
681        // `CudaContext::new(0)` here would fragment driver state across
682        // distinct contexts, defeat memory-pool sharing, and pin work to
683        // ordinal 0 even when the runtime probe chose a different device.
684        let runtime = super::super::device_runtime::GpuRuntime::global()
685            .ok_or_else(|| "cuda runtime unavailable".to_string())?;
686        context_and_stream_for(runtime.selected_device().ordinal)
687    }
688
689    pub fn pinned_htod<
690        T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits + Copy,
691    >(
692        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
693        src: &[T],
694    ) -> Result<CudaSlice<T>, String> {
695        // Originally this routine round-tripped the upload through a
696        // `CU_MEMHOSTALLOC_WRITECOMBINED` pinned staging buffer
697        // (`ctx.alloc_pinned`) to enable async DMA. In cudarc 0.19 the
698        // `PinnedHostSlice` returned from `alloc_pinned` carries an event that
699        // its `Drop` impl unconditionally `event.synchronize()`s before freeing
700        // the host mapping — see cudarc-0.19.7 `core.rs::PinnedHostSlice::drop`.
701        // Because the staging buffer goes out of scope at the end of this
702        // function, the host thread blocks here until the H2D copy completes,
703        // immediately defeating the "async" of pinned DMA. The net cost is two
704        // extra driver calls per upload (`cuMemHostAlloc_WC` + `cuMemFreeHost`)
705        // plus a forced stream synchronization, and the workspace ends up
706        // strictly slower than a plain pageable H2D — the driver already
707        // stages pageable copies internally via its own pinned pool, and that
708        // path does not block the issuing host thread for unrelated stream
709        // work. Issue a direct async H2D from the pageable buffer instead.
710        stream.clone_htod(src).map_err(|e| format!("cuda H2D: {e}"))
711    }
712
713    pub fn potrf_in_place(
714        solver: &DnHandle,
715        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
716        p: usize,
717        h: &mut CudaSlice<f64>,
718    ) -> Result<(), String> {
719        potrf_in_place_generic::<f64>(solver, stream, p, h)
720    }
721
722    pub fn potrs_in_place(
723        solver: &DnHandle,
724        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
725        p: usize,
726        nrhs: usize,
727        h: &CudaSlice<f64>,
728        rhs: &mut CudaSlice<f64>,
729    ) -> Result<(), String> {
730        potrs_in_place_generic::<f64>(solver, stream, p, nrhs, h, rhs)
731    }
732
733    /// Query the cuSOLVER POTRF workspace size for a p×p matrix.
734    ///
735    /// Called once at workspace construction to size the persistent workspace
736    /// buffer. Returns the number of f64 elements required.
737    pub fn potrf_query_lwork(
738        solver: &DnHandle,
739        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
740        p: usize,
741    ) -> Result<usize, String> {
742        potrf_bufsize_generic::<f64>(solver, stream, p)
743    }
744
745    /// POTRF factorization using pre-allocated workspace and info buffers.
746    ///
747    /// Does not allocate, does not download `info`. The caller is responsible
748    /// for calling [`check_deferred_potrf_info`] at end-of-fit to confirm no
749    /// factorization failed.
750    ///
751    /// `workspace` must have been allocated with at least `lwork` elements
752    /// (as reported by [`potrf_query_lwork`] at workspace construction).
753    /// `info_dev` is a 1-element device i32 buffer; after a failed
754    /// factorization it holds a positive integer but stays device-resident.
755    pub fn potrf_in_place_reuse(
756        solver: &DnHandle,
757        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
758        p: usize,
759        lwork: i32,
760        h: &mut CudaSlice<f64>,
761        workspace: &mut CudaSlice<f64>,
762        info_dev: &mut CudaSlice<i32>,
763    ) -> Result<(), String> {
764        let p_i = to_i32(p)?;
765        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
766        {
767            let (h_ptr, _h_record) = h.device_ptr_mut(stream);
768            let (work_ptr, _work_record) = workspace.device_ptr_mut(stream);
769            let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
770            // SAFETY: cuSOLVER potrf; h is p*p col-major, workspace was sized
771            // by potrf_query_lwork, info_dev is a pre-allocated 1-element i32
772            // device buffer. All buffers are live on the same stream.
773            let status = unsafe {
774                cusolver_sys::cusolverDnDpotrf(
775                    solver.cu(),
776                    uplo,
777                    p_i,
778                    h_ptr as *mut f64,
779                    p_i,
780                    work_ptr as *mut f64,
781                    lwork,
782                    info_ptr as *mut i32,
783                )
784            };
785            check_cusolver(status, "cusolverDnDpotrf")?;
786        }
787        Ok(())
788    }
789
790    /// POTRS triangular solve using a pre-allocated info buffer.
791    ///
792    /// Does not allocate, does not download `info`. The caller is responsible
793    /// for calling [`check_deferred_potrs_info`] at end-of-fit.
794    pub fn potrs_in_place_reuse(
795        solver: &DnHandle,
796        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
797        p: usize,
798        nrhs: usize,
799        h: &CudaSlice<f64>,
800        rhs: &mut CudaSlice<f64>,
801        info_dev: &mut CudaSlice<i32>,
802    ) -> Result<(), String> {
803        let p_i = to_i32(p)?;
804        let nrhs_i = to_i32(nrhs)?;
805        let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
806        {
807            let (h_ptr, _h_record) = h.device_ptr(stream);
808            let (rhs_ptr, _rhs_record) = rhs.device_ptr_mut(stream);
809            let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
810            // SAFETY: cuSOLVER potrs; h is a p*p Cholesky factor, rhs is p*nrhs,
811            // info_dev is a pre-allocated 1-element i32 device buffer.
812            let status = unsafe {
813                cusolver_sys::cusolverDnDpotrs(
814                    solver.cu(),
815                    uplo,
816                    p_i,
817                    nrhs_i,
818                    h_ptr as *const f64,
819                    p_i,
820                    rhs_ptr as *mut f64,
821                    p_i,
822                    info_ptr as *mut i32,
823                )
824            };
825            check_cusolver(status, "cusolverDnDpotrs")?;
826        }
827        Ok(())
828    }
829
830    /// Download the POTRF deferred info scalar and return an error if non-zero.
831    ///
832    /// Called once at end-of-fit (or whenever the convergence loop exits) to
833    /// surface any factorization failure that was deferred device-side by
834    /// [`potrf_in_place_reuse`].
835    pub fn check_deferred_potrf_info(
836        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
837        info_dev: &CudaSlice<i32>,
838    ) -> Result<(), String> {
839        let info_host = stream
840            .clone_dtoh(info_dev)
841            .map_err(|e| format!("download deferred potrf info: {e}"))?;
842        if info_host[0] == 0 {
843            Ok(())
844        } else {
845            Err(format!(
846                "cusolverDnDpotrf returned info={} (detected at end-of-fit)",
847                info_host[0]
848            ))
849        }
850    }
851
852    /// Download the POTRS deferred info scalar and return an error if non-zero.
853    ///
854    /// Mirrors [`check_deferred_potrf_info`] for the triangular-solve step.
855    pub fn check_deferred_potrs_info(
856        stream: &std::sync::Arc<cudarc::driver::CudaStream>,
857        info_dev: &CudaSlice<i32>,
858    ) -> Result<(), String> {
859        let info_host = stream
860            .clone_dtoh(info_dev)
861            .map_err(|e| format!("download deferred potrs info: {e}"))?;
862        if info_host[0] == 0 {
863            Ok(())
864        } else {
865            Err(format!(
866                "cusolverDnDpotrs returned info={} (detected at end-of-fit)",
867                info_host[0]
868            ))
869        }
870    }
871
872    pub fn cholesky_logdet_from_col_major(factor: &[f64], p: usize) -> f64 {
873        let factor = MatRef::from_column_major_slice(factor, p, p);
874        cholesky_factor_logdet(factor)
875    }
876
877    fn check_cusolver(
878        status: cusolver_sys::cusolverStatus_t,
879        label: &'static str,
880    ) -> Result<(), String> {
881        if status == cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
882            Ok(())
883        } else {
884            Err(format!("{label} failed with {status:?}"))
885        }
886    }
887
888    fn to_i32(value: usize) -> Result<i32, String> {
889        i32::try_from(value).map_err(|_| format!("CUDA dimension {value} exceeds i32"))
890    }
891}
892
893// These solver entry points are consumed by sibling crates (`gam-solve`'s
894// pirls/reml GPU paths, `gam-models`, ...) via `gam_gpu::solver::*`, so they
895// are part of gam-gpu's public surface. `potrf_in_place_generic` is the
896// only one with no cross-crate consumer; it stays crate-private and is
897// reached internally through `crate::solver::potrf_in_place_generic`.
898#[cfg(target_os = "linux")]
899pub use cuda::{
900    check_deferred_potrf_info, check_deferred_potrs_info, cholesky_logdet_from_col_major,
901    context_and_stream, pinned_htod, potrf_in_place, potrf_in_place_reuse, potrf_query_lwork,
902    potrs_in_place, potrs_in_place_reuse,
903};
904#[cfg(target_os = "linux")]
905pub(crate) use cuda::potrf_in_place_generic;
906
907/// Solve `A x = b` with fp32 Cholesky factorization + fp64-residual iterative
908/// refinement, automatically falling back to fp64 when the policy rejects the
909/// attempt or when the fp32 path fails / diverges.
910///
911/// The `p` threshold and maximum step count come from [`GpuDispatchPolicy`]
912/// constants — there is no user-facing knob. The decision path is:
913///
914/// 1. `policy.iterative_refinement_should_attempt(p)` → `false` or
915///    multi-column RHS: skip to the fp64 Cholesky path.
916/// 2. Attempt fp32 POTRF + up to `REFINEMENT_MAX_STEPS` residual-correction
917///    steps. Falls back to fp64 on:
918///    - fp32 POTRF info ≠ 0 (A is not SPD at f32 precision),
919///    - non-monotone residual (κ(A)·u_fp32 ≥ 1 regime).
920/// 3. On fp32 success the logdet is computed from the fp64 Cholesky factor —
921///    BUT only when `need_logdet` is true. The fp64 POTRF is an O(p³)
922///    factorization that fully negates the mixed-precision speedup (the whole
923///    point is to do the expensive factor in fp32), so a caller that only needs
924///    the *solution* (e.g. the PIRLS Newton direction solve, which discards the
925///    logdet) passes `need_logdet = false` and the redundant fp64 POTRF is
926///    skipped entirely — the returned logdet is `NaN` in that case. The solution
927///    is always full-fp64-accurate via the residual refinement regardless.
928///
929/// Returns `(solution, logdet, Some(RefinementOutcome))` when the fp32 path
930/// succeeded, or `(solution, logdet, None)` on the fp64 fallback. When
931/// `need_logdet` is false and the fp32 path succeeds, the logdet field is `NaN`.
932pub fn iterative_refinement_cholesky_solve(
933    hessian: ArrayView2<'_, f64>,
934    rhs: ArrayView2<'_, f64>,
935    need_logdet: bool,
936) -> Result<(Array2<f64>, f64, Option<RefinementOutcome>), String> {
937    #[cfg(not(target_os = "linux"))]
938    {
939        let (rows, cols) = hessian.dim();
940        return Err(format!(
941            "CUDA support not compiled; hessian={rows}x{cols}, rhs={}x{}, need_logdet={need_logdet}",
942            rhs.nrows(),
943            rhs.ncols()
944        ));
945    }
946
947    #[cfg(target_os = "linux")]
948    {
949        let runtime = super::device_runtime::GpuRuntime::global().ok_or_else(|| {
950            let (rows, cols) = hessian.dim();
951            format!(
952                "CUDA runtime unavailable; hessian={rows}x{cols}, rhs={}x{}",
953                rhs.nrows(),
954                rhs.ncols()
955            )
956        })?;
957        let p = hessian.nrows();
958
959        // Attempt fp32 + refinement only for single-column RHS with p large
960        // enough that the fp64 GEMV residual cost is amortised.
961        if rhs.ncols() == 1 && runtime.policy.iterative_refinement_should_attempt(p) {
962            let rhs_col = rhs.column(0);
963            let rhs_slice: Vec<f64> = rhs_col.iter().copied().collect();
964            if let Ok(outcome) = cuda::iterative_refinement_solve_impl(hessian, &rhs_slice) {
965                // fp32 + refinement succeeded; the refined solution is full
966                // fp64 accuracy. The logdet, however, needs the fp64 Cholesky
967                // factor (the fp32 diagonal is only fp32-accurate, and the
968                // logdet feeds the REML criterion / EDF). Run the fp64 POTRF
969                // ONLY when the caller actually consumes the logdet: otherwise
970                // that O(p³) factorization is pure overhead that cancels the
971                // mixed-precision win (the expensive factor would then run in
972                // BOTH precisions). A solution-only caller (PIRLS Newton
973                // direction, which discards the logdet) gets the genuine
974                // fp32-factor speedup; logdet is reported as NaN.
975                let mut sol = Array2::<f64>::zeros((p, 1));
976                sol.column_mut(0).assign(&outcome.solution);
977                if !need_logdet {
978                    return Ok((sol, f64::NAN, Some(outcome)));
979                }
980                if let Ok(logdet) = cuda::cholesky_logdet(hessian) {
981                    return Ok((sol, logdet, Some(outcome)));
982                }
983                // fp64 logdet failed (theoretically impossible for SPD A);
984                // fall through to plain fp64 path.
985            }
986            // fp32 path failed (not SPD at f32, or residual non-monotone) →
987            // fall through to fp64.
988        }
989
990        let (sol, logdet) = cuda::cholesky_solve(hessian, rhs)?;
991        Ok((sol, logdet, None))
992    }
993}
994
995pub fn cholesky_solve_gpu(
996    hessian: ArrayView2<'_, f64>,
997    rhs: ArrayView2<'_, f64>,
998) -> Result<(Array2<f64>, f64), String> {
999    // Route through iterative refinement. The function falls back to fp64
1000    // internally, so callers always get a valid result; the refinement
1001    // outcome metadata is intentionally not surfaced by this thin wrapper.
1002    // This wrapper returns the logdet, so it must request it (`need_logdet`).
1003    let result = iterative_refinement_cholesky_solve(hessian, rhs, /*need_logdet=*/ true)?;
1004    Ok((result.0, result.1))
1005}
1006
1007/// Solution-only mixed-precision solve: like [`cholesky_solve_gpu`] but skips
1008/// the redundant fp64 POTRF when the fp32 + refinement path succeeds, since the
1009/// caller does not consume the log-determinant. This is the path that delivers
1010/// the full mixed-precision speedup (expensive O(p³) factor stays fp32) for the
1011/// PIRLS Newton direction solve, where the logdet is discarded. The solution is
1012/// full fp64 accuracy via iterative refinement.
1013pub fn cholesky_solve_only_gpu(
1014    hessian: ArrayView2<'_, f64>,
1015    rhs: ArrayView2<'_, f64>,
1016) -> Result<Array2<f64>, String> {
1017    let result = iterative_refinement_cholesky_solve(hessian, rhs, /*need_logdet=*/ false)?;
1018    Ok(result.0)
1019}
1020
1021pub fn cholesky_lower_gpu(hessian: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
1022    #[cfg(not(target_os = "linux"))]
1023    {
1024        let (rows, cols) = hessian.dim();
1025        return Err(format!(
1026            "CUDA support not compiled for Cholesky factorization; hessian={rows}x{cols}"
1027        ));
1028    }
1029
1030    #[cfg(target_os = "linux")]
1031    {
1032        if super::device_runtime::GpuRuntime::global().is_none() {
1033            let (rows, cols) = hessian.dim();
1034            return Err(format!(
1035                "CUDA runtime unavailable for Cholesky factorization; hessian={rows}x{cols}"
1036            ));
1037        }
1038        cuda::cholesky_lower(hessian)
1039    }
1040}
1041
1042#[cfg(target_os = "linux")]
1043pub(crate) fn cholesky_lower_on_ordinal_gpu(
1044    ordinal: usize,
1045    hessian: ArrayView2<'_, f64>,
1046) -> Result<Array2<f64>, String> {
1047    cuda::cholesky_lower_on_ordinal(ordinal, hessian)
1048}