gam_gpu/solver.rs
1//! cuSOLVER-backed dense solver kernels for the GPU HAL.
2//!
3//! This module owns CUDA solver functionality that is shared by GPU linear
4//! algebra dispatch and higher-level solver code. CPU solves do not live behind
5//! these entry points: unavailable CUDA support is reported as an error.
6
7use ndarray::{Array2, ArrayView2};
8
9pub fn solver_backend_status() -> super::CudaBackendStatus {
10 super::cuda_backend_status()
11}
12
13/// Outcome reported by [`iterative_refinement_cholesky_solve`].
14#[derive(Clone, Debug)]
15pub struct RefinementOutcome {
16 /// Solution vector `x` satisfying `A x ≈ b`.
17 pub solution: ndarray::Array1<f64>,
18 /// `‖r‖ / ‖b‖` where `r = b − A x` after the last refinement step
19 /// (or after the initial fp32 solve when no steps were taken).
20 pub relative_residual: f64,
21 /// Precision path used for the factorization.
22 pub used_fp32_factor: bool,
23 /// Number of refinement steps taken (0 means only the initial solve ran).
24 pub refinement_steps: usize,
25}
26
27#[cfg(target_os = "linux")]
28mod cuda {
29 use crate::driver::{from_col_major, to_col_major};
30 use gam_linalg::faer_ndarray::cholesky_factor_logdet;
31 use cudarc::cublas::sys as cublas_sys;
32 use cudarc::cublas::{CudaBlas, Gemv, GemvConfig};
33 use cudarc::cusolver::{DnHandle, sys as cusolver_sys};
34 use cudarc::driver::{CudaContext, CudaSlice, DevicePtr, DevicePtrMut};
35 use faer::MatRef;
36 use ndarray::{Array2, ArrayView2};
37
38 pub(super) fn cholesky_solve(
39 hessian: ArrayView2<'_, f64>,
40 rhs: ArrayView2<'_, f64>,
41 ) -> Result<(Array2<f64>, f64), String> {
42 let (_, stream) = context_and_stream()?;
43 let (p, p2) = hessian.dim();
44 if p == 0 || p != p2 || rhs.nrows() != p {
45 return Err("Cholesky solve dimension mismatch".to_string());
46 }
47 let nrhs = rhs.ncols();
48 let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
49 let h_col = to_col_major(&hessian);
50 let rhs_col = to_col_major(&rhs);
51 let mut h_dev = pinned_htod(&stream, &h_col)?;
52 let mut rhs_dev = pinned_htod(&stream, &rhs_col)?;
53 potrf_in_place(&solver, &stream, p, &mut h_dev)?;
54 potrs_in_place(&solver, &stream, p, nrhs, &h_dev, &mut rhs_dev)?;
55 let factor_col = stream
56 .clone_dtoh(&h_dev)
57 .map_err(|e| format!("download Cholesky factor: {e}"))?;
58 let out_col = stream
59 .clone_dtoh(&rhs_dev)
60 .map_err(|e| format!("download solution: {e}"))?;
61 let solved =
62 from_col_major(&out_col, p, nrhs).ok_or("solution layout conversion failed")?;
63 Ok((solved, cholesky_logdet_from_col_major(&factor_col, p)))
64 }
65
66 /// fp64 log-determinant of an SPD matrix via POTRF only.
67 ///
68 /// This is [`cholesky_solve`] stripped of the triangular solve (POTRS) and
69 /// the solution download/layout conversion: the log-determinant depends
70 /// solely on the Cholesky factor's diagonal, so when a caller already holds
71 /// the solution (e.g. from fp32 + iterative refinement) and needs *only* an
72 /// accurate fp64 logdet, doing a full solve here would burn an O(p²·nrhs)
73 /// POTRS plus a host round-trip on a solution that is immediately discarded.
74 pub(super) fn cholesky_logdet(hessian: ArrayView2<'_, f64>) -> Result<f64, String> {
75 let (_, stream) = context_and_stream()?;
76 let (p, p2) = hessian.dim();
77 if p == 0 || p != p2 {
78 return Err("Cholesky logdet dimension mismatch".to_string());
79 }
80 let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
81 let h_col = to_col_major(&hessian);
82 let mut h_dev = pinned_htod(&stream, &h_col)?;
83 potrf_in_place(&solver, &stream, p, &mut h_dev)?;
84 let factor_col = stream
85 .clone_dtoh(&h_dev)
86 .map_err(|e| format!("download Cholesky factor: {e}"))?;
87 Ok(cholesky_logdet_from_col_major(&factor_col, p))
88 }
89
90 pub(super) fn cholesky_lower(hessian: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
91 let (_, stream) = context_and_stream()?;
92 cholesky_lower_on_stream(hessian, &stream)
93 }
94
95 pub(super) fn cholesky_lower_on_ordinal(
96 ordinal: usize,
97 hessian: ArrayView2<'_, f64>,
98 ) -> Result<Array2<f64>, String> {
99 let (_, stream) = context_and_stream_for(ordinal)?;
100 cholesky_lower_on_stream(hessian, &stream)
101 }
102
103 fn cholesky_lower_on_stream(
104 hessian: ArrayView2<'_, f64>,
105 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
106 ) -> Result<Array2<f64>, String> {
107 let (p, p2) = hessian.dim();
108 if p == 0 || p != p2 {
109 return Err("Cholesky factorization dimension mismatch".to_string());
110 }
111 let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
112 let h_col = to_col_major(&hessian);
113 let mut h_dev = pinned_htod(&stream, &h_col)?;
114 potrf_in_place(&solver, &stream, p, &mut h_dev)?;
115 let factor_col = stream
116 .clone_dtoh(&h_dev)
117 .map_err(|e| format!("download Cholesky factor: {e}"))?;
118 let mut lower =
119 from_col_major(&factor_col, p, p).ok_or("factor layout conversion failed")?;
120 for row in 0..p {
121 for col in (row + 1)..p {
122 lower[[row, col]] = 0.0;
123 }
124 }
125 Ok(lower)
126 }
127
// -----------------------------------------------------------------------
// Precision-generic Cholesky scaffold
//
// The POTRF / POTRS host scaffolds are identical across single and double
// precision except for the cuSOLVER symbol called and the device pointer
// type. `CholScalar` selects those per-precision pieces so the host-side
// allocation, info handling, and error formatting are written only once.
// The f64 wrappers (`potrf_in_place`, `potrs_in_place`, `potrf_query_lwork`)
// and the f32 wrappers (`spotrf_in_place`, `spotrs_in_place`) are thin
// shims over these generic helpers and keep their existing signatures.
// -----------------------------------------------------------------------
139
140 /// cuSOLVER scalar abstraction: selects the precision-specific POTRF/POTRS
141 /// symbols and the precision tag used in deferred-info error messages.
142 ///
143 /// The FFI into cuSOLVER lives inside the trait methods' bodies (in `unsafe`
144 /// blocks), so the trait and its methods are safe to call: each impl wires
145 /// its method bodies to the cuSOLVER entry points whose pointer arguments
146 /// match `Self` (e.g. `cusolverDnDpotrf` for `f64`). Implementors must keep
147 /// that pairing consistent — the device pointer passed in is typed `*mut
148 /// Self` / `*const Self`, so a mismatched symbol would hand cuSOLVER a
149 /// wrongly-typed buffer.
150 pub(crate) trait CholScalar:
151 cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits + Copy
152 {
153 /// cuSOLVER `*potrf_bufferSize`: `(handle, uplo, n, A, lda, *lwork)`.
154 ///
155 /// `a` is a live `n*n` column-major device buffer of type `Self`,
156 /// `lwork` is a host out-param. The unsafe FFI call is contained in the
157 /// method body.
158 fn potrf_buffer_size(
159 handle: cusolver_sys::cusolverDnHandle_t,
160 uplo: cusolver_sys::cublasFillMode_t,
161 n: i32,
162 a: *mut Self,
163 lda: i32,
164 lwork: *mut i32,
165 ) -> cusolver_sys::cusolverStatus_t;
166 /// cuSOLVER `*potrf`: `(handle, uplo, n, A, lda, work, lwork, info)`.
167 ///
168 /// Pointer args must reference live device buffers of the documented
169 /// shape; the unsafe FFI call is contained in the method body.
170 fn potrf(
171 handle: cusolver_sys::cusolverDnHandle_t,
172 uplo: cusolver_sys::cublasFillMode_t,
173 n: i32,
174 a: *mut Self,
175 lda: i32,
176 work: *mut Self,
177 lwork: i32,
178 info: *mut i32,
179 ) -> cusolver_sys::cusolverStatus_t;
180 /// cuSOLVER `*potrs`: `(handle, uplo, n, nrhs, A, lda, B, ldb, info)`.
181 ///
182 /// Pointer args must reference live device buffers of the documented
183 /// shape; the unsafe FFI call is contained in the method body.
184 fn potrs(
185 handle: cusolver_sys::cusolverDnHandle_t,
186 uplo: cusolver_sys::cublasFillMode_t,
187 n: i32,
188 nrhs: i32,
189 a: *const Self,
190 lda: i32,
191 b: *mut Self,
192 ldb: i32,
193 info: *mut i32,
194 ) -> cusolver_sys::cusolverStatus_t;
195 /// Symbol name fragment for error messages (e.g. `"Dpotrf"`).
196 const POTRF_NAME: &'static str;
197 const POTRS_NAME: &'static str;
198 /// Trailing clause appended to a POTRF "not SPD" error (e.g.
199 /// `" (matrix not SPD at f32)"`); empty for f64.
200 const POTRF_FAIL_SUFFIX: &'static str;
201 }
202
203 impl CholScalar for f64 {
204 fn potrf_buffer_size(
205 handle: cusolver_sys::cusolverDnHandle_t,
206 uplo: cusolver_sys::cublasFillMode_t,
207 n: i32,
208 a: *mut f64,
209 lda: i32,
210 lwork: *mut i32,
211 ) -> cusolver_sys::cusolverStatus_t {
212 // SAFETY: caller guarantees `a` is a live n*n column-major f64 device
213 // buffer and `lwork` is a valid host out-param; symbol matches f64.
214 unsafe { cusolver_sys::cusolverDnDpotrf_bufferSize(handle, uplo, n, a, lda, lwork) }
215 }
216 fn potrf(
217 handle: cusolver_sys::cusolverDnHandle_t,
218 uplo: cusolver_sys::cublasFillMode_t,
219 n: i32,
220 a: *mut f64,
221 lda: i32,
222 work: *mut f64,
223 lwork: i32,
224 info: *mut i32,
225 ) -> cusolver_sys::cusolverStatus_t {
226 // SAFETY: caller guarantees `a` is a live n*n column-major f64 buffer,
227 // `work` was sized by potrf_buffer_size, `info` is a 1-element i32
228 // device buffer; symbol matches f64.
229 unsafe { cusolver_sys::cusolverDnDpotrf(handle, uplo, n, a, lda, work, lwork, info) }
230 }
231 fn potrs(
232 handle: cusolver_sys::cusolverDnHandle_t,
233 uplo: cusolver_sys::cublasFillMode_t,
234 n: i32,
235 nrhs: i32,
236 a: *const f64,
237 lda: i32,
238 b: *mut f64,
239 ldb: i32,
240 info: *mut i32,
241 ) -> cusolver_sys::cusolverStatus_t {
242 // SAFETY: caller guarantees `a` is a live n*n f64 Cholesky factor,
243 // `b` is n*nrhs column-major f64, `info` is a 1-element i32 device
244 // buffer; symbol matches f64.
245 unsafe { cusolver_sys::cusolverDnDpotrs(handle, uplo, n, nrhs, a, lda, b, ldb, info) }
246 }
247 const POTRF_NAME: &'static str = "Dpotrf";
248 const POTRS_NAME: &'static str = "Dpotrs";
249 const POTRF_FAIL_SUFFIX: &'static str = "";
250 }
251
252 impl CholScalar for f32 {
253 fn potrf_buffer_size(
254 handle: cusolver_sys::cusolverDnHandle_t,
255 uplo: cusolver_sys::cublasFillMode_t,
256 n: i32,
257 a: *mut f32,
258 lda: i32,
259 lwork: *mut i32,
260 ) -> cusolver_sys::cusolverStatus_t {
261 // SAFETY: caller guarantees `a` is a live n*n column-major f32 device
262 // buffer and `lwork` is a valid host out-param; symbol matches f32.
263 unsafe { cusolver_sys::cusolverDnSpotrf_bufferSize(handle, uplo, n, a, lda, lwork) }
264 }
265 fn potrf(
266 handle: cusolver_sys::cusolverDnHandle_t,
267 uplo: cusolver_sys::cublasFillMode_t,
268 n: i32,
269 a: *mut f32,
270 lda: i32,
271 work: *mut f32,
272 lwork: i32,
273 info: *mut i32,
274 ) -> cusolver_sys::cusolverStatus_t {
275 // SAFETY: caller guarantees `a` is a live n*n column-major f32 buffer,
276 // `work` was sized by potrf_buffer_size, `info` is a 1-element i32
277 // device buffer; symbol matches f32.
278 unsafe { cusolver_sys::cusolverDnSpotrf(handle, uplo, n, a, lda, work, lwork, info) }
279 }
280 fn potrs(
281 handle: cusolver_sys::cusolverDnHandle_t,
282 uplo: cusolver_sys::cublasFillMode_t,
283 n: i32,
284 nrhs: i32,
285 a: *const f32,
286 lda: i32,
287 b: *mut f32,
288 ldb: i32,
289 info: *mut i32,
290 ) -> cusolver_sys::cusolverStatus_t {
291 // SAFETY: caller guarantees `a` is a live n*n f32 Cholesky factor,
292 // `b` is n*nrhs column-major f32, `info` is a 1-element i32 device
293 // buffer; symbol matches f32.
294 unsafe { cusolver_sys::cusolverDnSpotrs(handle, uplo, n, nrhs, a, lda, b, ldb, info) }
295 }
296 const POTRF_NAME: &'static str = "Spotrf";
297 const POTRS_NAME: &'static str = "Spotrs";
298 const POTRF_FAIL_SUFFIX: &'static str = " (matrix not SPD at f32)";
299 }
300
301 /// Query the cuSOLVER POTRF workspace size (element count) for a p×p
302 /// matrix at precision `T`. Allocates a temporary p×p dummy buffer for the
303 /// query.
304 fn potrf_bufsize_generic<T: CholScalar>(
305 solver: &DnHandle,
306 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
307 p: usize,
308 ) -> Result<usize, String> {
309 let p_i = to_i32(p)?;
310 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
311 let mut lwork = 0_i32;
312 let mut dummy = stream
313 .alloc_zeros::<T>(p.checked_mul(p).ok_or("p² overflow in lwork query")?)
314 .map_err(|e| format!("cuda alloc dummy for lwork query: {e}"))?;
315 {
316 let (ptr, _rec) = dummy.device_ptr_mut(stream);
317 // dummy is a live p*p device buffer of type T, lwork is a host i32;
318 // the unsafe cuSOLVER FFI is contained in T::potrf_buffer_size.
319 let status =
320 T::potrf_buffer_size(solver.cu(), uplo, p_i, ptr as *mut T, p_i, &mut lwork);
321 check_cusolver(status, "cusolverDn*potrf_bufferSize")?;
322 }
323 usize::try_from(lwork).map_err(|_| "negative potrf lwork".to_string())
324 }
325
326 /// Factor a p×p SPD device buffer in-place (lower-triangular Cholesky) at
327 /// precision `T`, querying and allocating its own workspace. Returns `Err`
328 /// if the matrix is singular/indefinite at precision `T`.
329 ///
330 /// This is the single-matrix POTRF core shared across the GPU layer:
331 /// `solver.rs`'s `potrf_in_place`/`spotrf_in_place` and `linalg.rs`'s
332 /// `potrf_lower_in_place` all route through it (the latter mapping the
333 /// `Result` to its `Option` contract at the boundary). The batched POTRF
334 /// (`cusolverDnDpotrfBatched`) in `linalg.rs` is intentionally separate.
335 pub(crate) fn potrf_in_place_generic<T: CholScalar>(
336 solver: &DnHandle,
337 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
338 p: usize,
339 a: &mut CudaSlice<T>,
340 ) -> Result<(), String> {
341 let p_i = to_i32(p)?;
342 let lwork = potrf_bufsize_generic::<T>(solver, stream, p)?;
343 let lwork_i = i32::try_from(lwork).map_err(|_| "negative potrf workspace".to_string())?;
344 let mut workspace = stream
345 .alloc_zeros::<T>(lwork.max(1))
346 .map_err(|e| format!("cuda alloc potrf workspace: {e}"))?;
347 let mut info = stream
348 .alloc_zeros::<i32>(1)
349 .map_err(|e| format!("cuda alloc potrf info: {e}"))?;
350 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
351 {
352 let (a_ptr, _a_rec) = a.device_ptr_mut(stream);
353 let (work_ptr, _work_rec) = workspace.device_ptr_mut(stream);
354 let (info_ptr, _info_rec) = info.device_ptr_mut(stream);
355 // a is p*p col-major T, workspace was sized by T::potrf_buffer_size,
356 // info is a 1-element i32 device buffer; the unsafe cuSOLVER FFI is
357 // contained in T::potrf.
358 let status = T::potrf(
359 solver.cu(),
360 uplo,
361 p_i,
362 a_ptr as *mut T,
363 p_i,
364 work_ptr as *mut T,
365 lwork_i,
366 info_ptr as *mut i32,
367 );
368 check_cusolver(status, "cusolverDn*potrf")?;
369 }
370 let info_host = stream
371 .clone_dtoh(&info)
372 .map_err(|e| format!("download potrf info: {e}"))?;
373 if info_host[0] == 0 {
374 Ok(())
375 } else {
376 Err(format!(
377 "cusolverDn{} returned info={}{}",
378 T::POTRF_NAME,
379 info_host[0],
380 T::POTRF_FAIL_SUFFIX
381 ))
382 }
383 }
384
385 /// Triangular solve using a pre-factored Cholesky lower-triangle at
386 /// precision `T`. Solves `A · x = rhs` in-place into `rhs` (column-major,
387 /// p × nrhs), allocating and downloading its own info scalar.
388 fn potrs_in_place_generic<T: CholScalar>(
389 solver: &DnHandle,
390 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
391 p: usize,
392 nrhs: usize,
393 factor: &CudaSlice<T>,
394 rhs: &mut CudaSlice<T>,
395 ) -> Result<(), String> {
396 let p_i = to_i32(p)?;
397 let nrhs_i = to_i32(nrhs)?;
398 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
399 let mut info = stream
400 .alloc_zeros::<i32>(1)
401 .map_err(|e| format!("cuda alloc potrs info: {e}"))?;
402 {
403 let (f_ptr, _f_rec) = factor.device_ptr(stream);
404 let (r_ptr, _r_rec) = rhs.device_ptr_mut(stream);
405 let (info_ptr, _info_rec) = info.device_ptr_mut(stream);
406 // factor is a p*p lower-triangular T from potrf, rhs is p*nrhs
407 // col-major T, info is a 1-element i32 device buffer; leading dims
408 // match column-major p_i. The unsafe cuSOLVER FFI is contained in
409 // T::potrs.
410 let status = T::potrs(
411 solver.cu(),
412 uplo,
413 p_i,
414 nrhs_i,
415 f_ptr as *const T,
416 p_i,
417 r_ptr as *mut T,
418 p_i,
419 info_ptr as *mut i32,
420 );
421 check_cusolver(status, "cusolverDn*potrs")?;
422 }
423 let info_host = stream
424 .clone_dtoh(&info)
425 .map_err(|e| format!("download potrs info: {e}"))?;
426 if info_host[0] == 0 {
427 Ok(())
428 } else {
429 Err(format!(
430 "cusolverDn{} returned info={}",
431 T::POTRS_NAME,
432 info_host[0]
433 ))
434 }
435 }
436
437 // -----------------------------------------------------------------------
438 // fp32 entry points (thin wrappers over the precision-generic scaffold)
439 // -----------------------------------------------------------------------
440
441 /// Factor a p×p symmetric positive-definite f32 device buffer in-place
442 /// (lower-triangular Cholesky). Returns `Err` if the matrix is
443 /// singular/indefinite.
444 fn spotrf_in_place(
445 solver: &DnHandle,
446 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
447 p: usize,
448 a: &mut CudaSlice<f32>,
449 ) -> Result<(), String> {
450 potrf_in_place_generic::<f32>(solver, stream, p, a)
451 }
452
453 /// Triangular solve using a pre-factored fp32 Cholesky lower-triangle.
454 /// Solves `A · x = rhs` in-place into `rhs` (column-major, p × nrhs).
455 fn spotrs_in_place(
456 solver: &DnHandle,
457 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
458 p: usize,
459 nrhs: usize,
460 factor: &CudaSlice<f32>,
461 rhs: &mut CudaSlice<f32>,
462 ) -> Result<(), String> {
463 potrs_in_place_generic::<f32>(solver, stream, p, nrhs, factor, rhs)
464 }
465
466 // -----------------------------------------------------------------------
467 // fp64 DGEMV residual: r = b − A·x in double precision
468 // -----------------------------------------------------------------------
469
470 /// Compute `r = b − A·x` in fp64 where A is p×p and x, b, r are length p.
471 ///
472 /// Overwrites the output buffer `r_dev` with the residual. Uses
473 /// `cublasDgemv` (CUBLAS_OP_N): `r = 1·A·x + 0·0 = A·x`, then the host
474 /// subtracts from b. Because p is small here (the policy gates on p ≥ 64
475 /// and the Newton system is p×p), downloading the p-vector for the host
476 /// subtract is cheap relative to the GEMV.
477 fn residual_norm_and_vec(
478 blas: &CudaBlas,
479 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
480 p: usize,
481 a_dev: &CudaSlice<f64>,
482 x_dev: &CudaSlice<f64>,
483 b_host: &[f64],
484 ) -> Result<(Vec<f64>, f64), String> {
485 let p_i = to_i32(p)?;
486 // ax_dev = A · x
487 let mut ax_dev = stream
488 .alloc_zeros::<f64>(p)
489 .map_err(|e| format!("alloc ax: {e}"))?;
490 {
491 let cfg = GemvConfig::<f64> {
492 trans: cublas_sys::cublasOperation_t::CUBLAS_OP_N,
493 m: p_i,
494 n: p_i,
495 alpha: 1.0_f64,
496 lda: p_i,
497 incx: 1,
498 beta: 0.0_f64,
499 incy: 1,
500 };
501 // SAFETY: cuBLAS Dgemv; a_dev is p*p col-major f64, x_dev is
502 // length-p f64, ax_dev is length-p output; all on the same stream.
503 unsafe { blas.gemv(cfg, a_dev, x_dev, &mut ax_dev) }
504 .map_err(|e| format!("cublasDgemv for residual: {e}"))?;
505 }
506 let ax_host = stream
507 .clone_dtoh(&ax_dev)
508 .map_err(|e| format!("download A·x: {e}"))?;
509 // r = b − A·x (host subtract; p is small)
510 let r: Vec<f64> = b_host
511 .iter()
512 .zip(ax_host.iter())
513 .map(|(bi, axi)| bi - axi)
514 .collect();
515 let norm_r = r.iter().map(|v| v * v).sum::<f64>().sqrt();
516 Ok((r, norm_r))
517 }
518
519 // -----------------------------------------------------------------------
520 // Iterative refinement: fp32 factor → fp32 solve → fp64 residual loop
521 // -----------------------------------------------------------------------
522
523 /// Solve `A x = b` using an fp32 Cholesky factorization with up to
524 /// `max_steps` fp64-residual iterative refinement corrections.
525 ///
526 /// # Algorithm
527 ///
528 /// 1. Cast `A` (f64) to f32 on device. Factor in fp32 (POTRF).
529 /// 2. Cast `b` (f64) to f32. Solve `A x = b` in fp32 (POTRS). Lift `x`
530 /// to f64.
531 /// 3. Loop up to `max_steps`:
532 /// a. `r = b − A·x` accumulated in fp64 (cuBLAS Dgemv).
533 /// b. `‖r‖ / ‖b‖ ≤ tol` → converged, break.
534 /// c. Residual did not drop below previous step → bail, return `Err`.
535 /// d. Cast `r` to f32. Solve `A e = r` in fp32. `x += e` (f64).
536 /// 4. Return `(x, ‖r‖/‖b‖, refinement_steps)`.
537 ///
538 /// Returns `Err` when the fp32 POTRF fails (not SPD at f32) or when the
539 /// residual does not decrease monotonically (κ(A)·u_f32 ≥ 1 regime).
540 /// Callers should fall back to fp64 POTRF on `Err`.
541 pub(super) fn iterative_refinement_solve_impl(
542 hessian: ArrayView2<'_, f64>,
543 rhs: &[f64],
544 ) -> Result<super::RefinementOutcome, String> {
545 use crate::policy::GpuDispatchPolicy;
546 let (p, p2) = hessian.dim();
547 if p == 0 || p != p2 || rhs.len() != p {
548 return Err("iterative_refinement_solve: dimension mismatch".to_string());
549 }
550 let max_steps = GpuDispatchPolicy::REFINEMENT_MAX_STEPS;
551 let tol = GpuDispatchPolicy::REFINEMENT_TOL;
552
553 let (_, stream) = context_and_stream()?;
554 let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
555 let blas = CudaBlas::new(stream.clone()).map_err(|e| format!("cublas init: {e}"))?;
556
557 // Upload fp64 hessian for residual GEMV.
558 let h_col_f64 = to_col_major(&hessian);
559 let a_dev_f64 = pinned_htod(&stream, &h_col_f64)?;
560
561 // Cast A to f32 and upload.
562 let h_col_f32: Vec<f32> = h_col_f64.iter().map(|&v| v as f32).collect();
563 let mut a_dev_f32 =
564 pinned_htod(&stream, &h_col_f32).map_err(|e| format!("upload f32 A: {e}"))?;
565
566 // fp32 POTRF — returns Err if A is not SPD at f32 precision.
567 spotrf_in_place(&solver, &stream, p, &mut a_dev_f32)?;
568
569 // Cast b to f32 and upload; solve in fp32.
570 let b_f32: Vec<f32> = rhs.iter().map(|&v| v as f32).collect();
571 let mut x_dev_f32 =
572 pinned_htod(&stream, &b_f32).map_err(|e| format!("upload f32 rhs: {e}"))?;
573 spotrs_in_place(&solver, &stream, p, 1, &a_dev_f32, &mut x_dev_f32)?;
574
575 // Lift x to f64.
576 let x_f32 = stream
577 .clone_dtoh(&x_dev_f32)
578 .map_err(|e| format!("download f32 x: {e}"))?;
579 let mut x: Vec<f64> = x_f32.iter().map(|&v| v as f64).collect();
580
581 // Compute ‖b‖ for relative residual.
582 let norm_b = rhs.iter().map(|v| v * v).sum::<f64>().sqrt();
583 let norm_b_safe = if norm_b > 0.0 { norm_b } else { 1.0 };
584
585 let mut x_dev_f64 = pinned_htod(&stream, &x).map_err(|e| format!("upload f64 x: {e}"))?;
586 let (r0, norm_r0) = residual_norm_and_vec(&blas, &stream, p, &a_dev_f64, &x_dev_f64, rhs)?;
587 let mut rel_residual = norm_r0 / norm_b_safe;
588
589 // Early exit: already converged after initial solve.
590 if rel_residual <= tol {
591 return Ok(super::RefinementOutcome {
592 solution: ndarray::Array1::from_vec(x),
593 relative_residual: rel_residual,
594 used_fp32_factor: true,
595 refinement_steps: 0,
596 });
597 }
598
599 let mut r = r0;
600 let mut prev_norm_r = norm_r0;
601 let mut steps_taken = 0_usize;
602
603 for _ in 0..max_steps {
604 // Cast residual to f32, solve A e = r in fp32.
605 let r_f32: Vec<f32> = r.iter().map(|&v| v as f32).collect();
606 let mut e_dev_f32 =
607 pinned_htod(&stream, &r_f32).map_err(|e| format!("upload f32 residual: {e}"))?;
608 spotrs_in_place(&solver, &stream, p, 1, &a_dev_f32, &mut e_dev_f32)?;
609
610 // x += e in f64.
611 let e_f32 = stream
612 .clone_dtoh(&e_dev_f32)
613 .map_err(|e| format!("download f32 e: {e}"))?;
614 for (xi, ei) in x.iter_mut().zip(e_f32.iter()) {
615 *xi += *ei as f64;
616 }
617 steps_taken += 1;
618
619 // Reupload x_dev_f64 and compute new residual.
620 x_dev_f64 = pinned_htod(&stream, &x).map_err(|e| format!("upload refined x: {e}"))?;
621 let (r_new, norm_r_new) =
622 residual_norm_and_vec(&blas, &stream, p, &a_dev_f64, &x_dev_f64, rhs)?;
623 rel_residual = norm_r_new / norm_b_safe;
624
625 // Check monotone decrease. Non-monotone → κ(A)·u ≥ 1.
626 if norm_r_new >= prev_norm_r {
627 return Err(format!(
628 "iterative refinement: residual not decreasing ({norm_r_new:.3e} ≥ {prev_norm_r:.3e}); \
629 κ(A)·u_f32 ≥ 1, cannot refine"
630 ));
631 }
632 prev_norm_r = norm_r_new;
633 r = r_new;
634
635 if rel_residual <= tol {
636 break;
637 }
638 }
639
640 Ok(super::RefinementOutcome {
641 solution: ndarray::Array1::from_vec(x),
642 relative_residual: rel_residual,
643 used_fp32_factor: true,
644 refinement_steps: steps_taken,
645 })
646 }
647
648 /// Bind a specific device ordinal's cached context on the calling thread and
649 /// open a fresh stream on it. This is the per-ordinal entry point used by
650 /// multi-GPU fan-out (`crate::pool::scatter_batched` workers) so a
651 /// Cholesky / TRSM can target the device the worker thread owns. The
652 /// primary-device convenience wrapper [`context_and_stream`] calls this with
653 /// the probe-selected ordinal.
654 pub(crate) fn context_and_stream_for(
655 ordinal: usize,
656 ) -> Result<
657 (
658 std::sync::Arc<CudaContext>,
659 std::sync::Arc<cudarc::driver::CudaStream>,
660 ),
661 String,
662 > {
663 let ctx = super::super::device_runtime::cuda_context_for(ordinal)
664 .ok_or_else(|| format!("cuda context for ordinal {ordinal} unavailable"))?;
665 ctx.bind_to_thread()
666 .map_err(|e| format!("cuda context bind_to_thread: {e}"))?;
667 let stream = ctx.new_stream().map_err(|e| format!("cuda stream: {e}"))?;
668 Ok((ctx, stream))
669 }
670
671 pub fn context_and_stream() -> Result<
672 (
673 std::sync::Arc<CudaContext>,
674 std::sync::Arc<cudarc::driver::CudaStream>,
675 ),
676 String,
677 > {
678 // Route through the runtime's cached primary context for the selected
679 // device so every CUDA client in the process (calibration, session,
680 // cuSolver) shares one CUcontext per ordinal. Falling back to
681 // `CudaContext::new(0)` here would fragment driver state across
682 // distinct contexts, defeat memory-pool sharing, and pin work to
683 // ordinal 0 even when the runtime probe chose a different device.
684 let runtime = super::super::device_runtime::GpuRuntime::global()
685 .ok_or_else(|| "cuda runtime unavailable".to_string())?;
686 context_and_stream_for(runtime.selected_device().ordinal)
687 }
688
689 pub fn pinned_htod<
690 T: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits + Copy,
691 >(
692 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
693 src: &[T],
694 ) -> Result<CudaSlice<T>, String> {
695 // Originally this routine round-tripped the upload through a
696 // `CU_MEMHOSTALLOC_WRITECOMBINED` pinned staging buffer
697 // (`ctx.alloc_pinned`) to enable async DMA. In cudarc 0.19 the
698 // `PinnedHostSlice` returned from `alloc_pinned` carries an event that
699 // its `Drop` impl unconditionally `event.synchronize()`s before freeing
700 // the host mapping — see cudarc-0.19.7 `core.rs::PinnedHostSlice::drop`.
701 // Because the staging buffer goes out of scope at the end of this
702 // function, the host thread blocks here until the H2D copy completes,
703 // immediately defeating the "async" of pinned DMA. The net cost is two
704 // extra driver calls per upload (`cuMemHostAlloc_WC` + `cuMemFreeHost`)
705 // plus a forced stream synchronization, and the workspace ends up
706 // strictly slower than a plain pageable H2D — the driver already
707 // stages pageable copies internally via its own pinned pool, and that
708 // path does not block the issuing host thread for unrelated stream
709 // work. Issue a direct async H2D from the pageable buffer instead.
710 stream.clone_htod(src).map_err(|e| format!("cuda H2D: {e}"))
711 }
712
713 pub fn potrf_in_place(
714 solver: &DnHandle,
715 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
716 p: usize,
717 h: &mut CudaSlice<f64>,
718 ) -> Result<(), String> {
719 potrf_in_place_generic::<f64>(solver, stream, p, h)
720 }
721
722 pub fn potrs_in_place(
723 solver: &DnHandle,
724 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
725 p: usize,
726 nrhs: usize,
727 h: &CudaSlice<f64>,
728 rhs: &mut CudaSlice<f64>,
729 ) -> Result<(), String> {
730 potrs_in_place_generic::<f64>(solver, stream, p, nrhs, h, rhs)
731 }
732
733 /// Query the cuSOLVER POTRF workspace size for a p×p matrix.
734 ///
735 /// Called once at workspace construction to size the persistent workspace
736 /// buffer. Returns the number of f64 elements required.
737 pub fn potrf_query_lwork(
738 solver: &DnHandle,
739 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
740 p: usize,
741 ) -> Result<usize, String> {
742 potrf_bufsize_generic::<f64>(solver, stream, p)
743 }
744
745 /// POTRF factorization using pre-allocated workspace and info buffers.
746 ///
747 /// Does not allocate, does not download `info`. The caller is responsible
748 /// for calling [`check_deferred_potrf_info`] at end-of-fit to confirm no
749 /// factorization failed.
750 ///
751 /// `workspace` must have been allocated with at least `lwork` elements
752 /// (as reported by [`potrf_query_lwork`] at workspace construction).
753 /// `info_dev` is a 1-element device i32 buffer; after a failed
754 /// factorization it holds a positive integer but stays device-resident.
755 pub fn potrf_in_place_reuse(
756 solver: &DnHandle,
757 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
758 p: usize,
759 lwork: i32,
760 h: &mut CudaSlice<f64>,
761 workspace: &mut CudaSlice<f64>,
762 info_dev: &mut CudaSlice<i32>,
763 ) -> Result<(), String> {
764 let p_i = to_i32(p)?;
765 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
766 {
767 let (h_ptr, _h_record) = h.device_ptr_mut(stream);
768 let (work_ptr, _work_record) = workspace.device_ptr_mut(stream);
769 let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
770 // SAFETY: cuSOLVER potrf; h is p*p col-major, workspace was sized
771 // by potrf_query_lwork, info_dev is a pre-allocated 1-element i32
772 // device buffer. All buffers are live on the same stream.
773 let status = unsafe {
774 cusolver_sys::cusolverDnDpotrf(
775 solver.cu(),
776 uplo,
777 p_i,
778 h_ptr as *mut f64,
779 p_i,
780 work_ptr as *mut f64,
781 lwork,
782 info_ptr as *mut i32,
783 )
784 };
785 check_cusolver(status, "cusolverDnDpotrf")?;
786 }
787 Ok(())
788 }
789
790 /// POTRS triangular solve using a pre-allocated info buffer.
791 ///
792 /// Does not allocate, does not download `info`. The caller is responsible
793 /// for calling [`check_deferred_potrs_info`] at end-of-fit.
794 pub fn potrs_in_place_reuse(
795 solver: &DnHandle,
796 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
797 p: usize,
798 nrhs: usize,
799 h: &CudaSlice<f64>,
800 rhs: &mut CudaSlice<f64>,
801 info_dev: &mut CudaSlice<i32>,
802 ) -> Result<(), String> {
803 let p_i = to_i32(p)?;
804 let nrhs_i = to_i32(nrhs)?;
805 let uplo = cusolver_sys::cublasFillMode_t::CUBLAS_FILL_MODE_LOWER;
806 {
807 let (h_ptr, _h_record) = h.device_ptr(stream);
808 let (rhs_ptr, _rhs_record) = rhs.device_ptr_mut(stream);
809 let (info_ptr, _info_record) = info_dev.device_ptr_mut(stream);
810 // SAFETY: cuSOLVER potrs; h is a p*p Cholesky factor, rhs is p*nrhs,
811 // info_dev is a pre-allocated 1-element i32 device buffer.
812 let status = unsafe {
813 cusolver_sys::cusolverDnDpotrs(
814 solver.cu(),
815 uplo,
816 p_i,
817 nrhs_i,
818 h_ptr as *const f64,
819 p_i,
820 rhs_ptr as *mut f64,
821 p_i,
822 info_ptr as *mut i32,
823 )
824 };
825 check_cusolver(status, "cusolverDnDpotrs")?;
826 }
827 Ok(())
828 }
829
830 /// Download the POTRF deferred info scalar and return an error if non-zero.
831 ///
832 /// Called once at end-of-fit (or whenever the convergence loop exits) to
833 /// surface any factorization failure that was deferred device-side by
834 /// [`potrf_in_place_reuse`].
835 pub fn check_deferred_potrf_info(
836 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
837 info_dev: &CudaSlice<i32>,
838 ) -> Result<(), String> {
839 let info_host = stream
840 .clone_dtoh(info_dev)
841 .map_err(|e| format!("download deferred potrf info: {e}"))?;
842 if info_host[0] == 0 {
843 Ok(())
844 } else {
845 Err(format!(
846 "cusolverDnDpotrf returned info={} (detected at end-of-fit)",
847 info_host[0]
848 ))
849 }
850 }
851
852 /// Download the POTRS deferred info scalar and return an error if non-zero.
853 ///
854 /// Mirrors [`check_deferred_potrf_info`] for the triangular-solve step.
855 pub fn check_deferred_potrs_info(
856 stream: &std::sync::Arc<cudarc::driver::CudaStream>,
857 info_dev: &CudaSlice<i32>,
858 ) -> Result<(), String> {
859 let info_host = stream
860 .clone_dtoh(info_dev)
861 .map_err(|e| format!("download deferred potrs info: {e}"))?;
862 if info_host[0] == 0 {
863 Ok(())
864 } else {
865 Err(format!(
866 "cusolverDnDpotrs returned info={} (detected at end-of-fit)",
867 info_host[0]
868 ))
869 }
870 }
871
872 pub fn cholesky_logdet_from_col_major(factor: &[f64], p: usize) -> f64 {
873 let factor = MatRef::from_column_major_slice(factor, p, p);
874 cholesky_factor_logdet(factor)
875 }
876
877 fn check_cusolver(
878 status: cusolver_sys::cusolverStatus_t,
879 label: &'static str,
880 ) -> Result<(), String> {
881 if status == cusolver_sys::cusolverStatus_t::CUSOLVER_STATUS_SUCCESS {
882 Ok(())
883 } else {
884 Err(format!("{label} failed with {status:?}"))
885 }
886 }
887
/// Narrow a host-side dimension to the `i32` expected by the CUDA C APIs,
/// reporting values above `i32::MAX` as an error instead of truncating.
fn to_i32(value: usize) -> Result<i32, String> {
    match i32::try_from(value) {
        Ok(narrowed) => Ok(narrowed),
        Err(_) => Err(format!("CUDA dimension {value} exceeds i32")),
    }
}
891}
892
893// These solver entry points are consumed by sibling crates (`gam-solve`'s
894// pirls/reml GPU paths, `gam-models`, ...) via `gam_gpu::solver::*`, so they
895// are part of gam-gpu's public surface. `potrf_in_place_generic` is the
896// only one with no cross-crate consumer; it stays crate-private and is
897// reached internally through `crate::solver::potrf_in_place_generic`.
898#[cfg(target_os = "linux")]
899pub use cuda::{
900 check_deferred_potrf_info, check_deferred_potrs_info, cholesky_logdet_from_col_major,
901 context_and_stream, pinned_htod, potrf_in_place, potrf_in_place_reuse, potrf_query_lwork,
902 potrs_in_place, potrs_in_place_reuse,
903};
904#[cfg(target_os = "linux")]
905pub(crate) use cuda::potrf_in_place_generic;
906
907/// Solve `A x = b` with fp32 Cholesky factorization + fp64-residual iterative
908/// refinement, automatically falling back to fp64 when the policy rejects the
909/// attempt or when the fp32 path fails / diverges.
910///
911/// The `p` threshold and maximum step count come from [`GpuDispatchPolicy`]
912/// constants — there is no user-facing knob. The decision path is:
913///
914/// 1. `policy.iterative_refinement_should_attempt(p)` → `false` or
915/// multi-column RHS: skip to the fp64 Cholesky path.
916/// 2. Attempt fp32 POTRF + up to `REFINEMENT_MAX_STEPS` residual-correction
917/// steps. Falls back to fp64 on:
918/// - fp32 POTRF info ≠ 0 (A is not SPD at f32 precision),
919/// - non-monotone residual (κ(A)·u_fp32 ≥ 1 regime).
920/// 3. On fp32 success the logdet is computed from the fp64 Cholesky factor —
921/// BUT only when `need_logdet` is true. The fp64 POTRF is an O(p³)
922/// factorization that fully negates the mixed-precision speedup (the whole
923/// point is to do the expensive factor in fp32), so a caller that only needs
924/// the *solution* (e.g. the PIRLS Newton direction solve, which discards the
925/// logdet) passes `need_logdet = false` and the redundant fp64 POTRF is
926/// skipped entirely — the returned logdet is `NaN` in that case. The solution
927/// is always full-fp64-accurate via the residual refinement regardless.
928///
929/// Returns `(solution, logdet, Some(RefinementOutcome))` when the fp32 path
930/// succeeded, or `(solution, logdet, None)` on the fp64 fallback. When
931/// `need_logdet` is false and the fp32 path succeeds, the logdet field is `NaN`.
932pub fn iterative_refinement_cholesky_solve(
933 hessian: ArrayView2<'_, f64>,
934 rhs: ArrayView2<'_, f64>,
935 need_logdet: bool,
936) -> Result<(Array2<f64>, f64, Option<RefinementOutcome>), String> {
937 #[cfg(not(target_os = "linux"))]
938 {
939 let (rows, cols) = hessian.dim();
940 return Err(format!(
941 "CUDA support not compiled; hessian={rows}x{cols}, rhs={}x{}, need_logdet={need_logdet}",
942 rhs.nrows(),
943 rhs.ncols()
944 ));
945 }
946
947 #[cfg(target_os = "linux")]
948 {
949 let runtime = super::device_runtime::GpuRuntime::global().ok_or_else(|| {
950 let (rows, cols) = hessian.dim();
951 format!(
952 "CUDA runtime unavailable; hessian={rows}x{cols}, rhs={}x{}",
953 rhs.nrows(),
954 rhs.ncols()
955 )
956 })?;
957 let p = hessian.nrows();
958
959 // Attempt fp32 + refinement only for single-column RHS with p large
960 // enough that the fp64 GEMV residual cost is amortised.
961 if rhs.ncols() == 1 && runtime.policy.iterative_refinement_should_attempt(p) {
962 let rhs_col = rhs.column(0);
963 let rhs_slice: Vec<f64> = rhs_col.iter().copied().collect();
964 if let Ok(outcome) = cuda::iterative_refinement_solve_impl(hessian, &rhs_slice) {
965 // fp32 + refinement succeeded; the refined solution is full
966 // fp64 accuracy. The logdet, however, needs the fp64 Cholesky
967 // factor (the fp32 diagonal is only fp32-accurate, and the
968 // logdet feeds the REML criterion / EDF). Run the fp64 POTRF
969 // ONLY when the caller actually consumes the logdet: otherwise
970 // that O(p³) factorization is pure overhead that cancels the
971 // mixed-precision win (the expensive factor would then run in
972 // BOTH precisions). A solution-only caller (PIRLS Newton
973 // direction, which discards the logdet) gets the genuine
974 // fp32-factor speedup; logdet is reported as NaN.
975 let mut sol = Array2::<f64>::zeros((p, 1));
976 sol.column_mut(0).assign(&outcome.solution);
977 if !need_logdet {
978 return Ok((sol, f64::NAN, Some(outcome)));
979 }
980 if let Ok(logdet) = cuda::cholesky_logdet(hessian) {
981 return Ok((sol, logdet, Some(outcome)));
982 }
983 // fp64 logdet failed (theoretically impossible for SPD A);
984 // fall through to plain fp64 path.
985 }
986 // fp32 path failed (not SPD at f32, or residual non-monotone) →
987 // fall through to fp64.
988 }
989
990 let (sol, logdet) = cuda::cholesky_solve(hessian, rhs)?;
991 Ok((sol, logdet, None))
992 }
993}
994
995pub fn cholesky_solve_gpu(
996 hessian: ArrayView2<'_, f64>,
997 rhs: ArrayView2<'_, f64>,
998) -> Result<(Array2<f64>, f64), String> {
999 // Route through iterative refinement. The function falls back to fp64
1000 // internally, so callers always get a valid result; the refinement
1001 // outcome metadata is intentionally not surfaced by this thin wrapper.
1002 // This wrapper returns the logdet, so it must request it (`need_logdet`).
1003 let result = iterative_refinement_cholesky_solve(hessian, rhs, /*need_logdet=*/ true)?;
1004 Ok((result.0, result.1))
1005}
1006
1007/// Solution-only mixed-precision solve: like [`cholesky_solve_gpu`] but skips
1008/// the redundant fp64 POTRF when the fp32 + refinement path succeeds, since the
1009/// caller does not consume the log-determinant. This is the path that delivers
1010/// the full mixed-precision speedup (expensive O(p³) factor stays fp32) for the
1011/// PIRLS Newton direction solve, where the logdet is discarded. The solution is
1012/// full fp64 accuracy via iterative refinement.
1013pub fn cholesky_solve_only_gpu(
1014 hessian: ArrayView2<'_, f64>,
1015 rhs: ArrayView2<'_, f64>,
1016) -> Result<Array2<f64>, String> {
1017 let result = iterative_refinement_cholesky_solve(hessian, rhs, /*need_logdet=*/ false)?;
1018 Ok(result.0)
1019}
1020
1021pub fn cholesky_lower_gpu(hessian: ArrayView2<'_, f64>) -> Result<Array2<f64>, String> {
1022 #[cfg(not(target_os = "linux"))]
1023 {
1024 let (rows, cols) = hessian.dim();
1025 return Err(format!(
1026 "CUDA support not compiled for Cholesky factorization; hessian={rows}x{cols}"
1027 ));
1028 }
1029
1030 #[cfg(target_os = "linux")]
1031 {
1032 if super::device_runtime::GpuRuntime::global().is_none() {
1033 let (rows, cols) = hessian.dim();
1034 return Err(format!(
1035 "CUDA runtime unavailable for Cholesky factorization; hessian={rows}x{cols}"
1036 ));
1037 }
1038 cuda::cholesky_lower(hessian)
1039 }
1040}
1041
1042#[cfg(target_os = "linux")]
1043pub(crate) fn cholesky_lower_on_ordinal_gpu(
1044 ordinal: usize,
1045 hessian: ArrayView2<'_, f64>,
1046) -> Result<Array2<f64>, String> {
1047 cuda::cholesky_lower_on_ordinal(ordinal, hessian)
1048}