Skip to main content

gam_solve/gpu/
reml_gpu.rs

1//! Exact GPU REML evidence + derivative gradient.
2//!
3//! Refactor (Block 2.1, math team section 18): the penalized Hessian `H` is
4//! Cholesky-factored exactly **once** on device, the factor is held resident,
5//! and every derivative Hessian `H_j` is solved through the cached factor
//! with a single batched `potrs` call (`nrhs = d_rho · p`). Previously each
//! derivative re-issued the full `cholesky_solve_gpu` path, which uploaded
//! `H`, allocated and ran `potrf`, and downloaded the factor again — i.e.
//! `d_rho + 1` separate `O(p^3)` factorizations (plus the matching host/device
//! round-trips) where a single one suffices, all serialized onto the device.
11//!
//! The CPU fallback — used on non-Linux targets, and on Linux when no GPU
//! runtime is available — still calls `pirls_gpu::cholesky_solve_gpu` once per
//! derivative, exactly as the pre-refactor code did, so its results are
//! numerically unchanged (and it keeps the per-derivative factorization
//! overhead). The batched path is Linux-only because that is where CUDA runs.
16
17use ndarray::{Array1, ArrayView2};
18
19#[derive(Clone, Debug)]
20pub struct RemlGpuInput<'a> {
21    pub penalized_hessian: ArrayView2<'a, f64>,
22    pub derivative_hessians: Vec<ArrayView2<'a, f64>>,
23}
24
25#[derive(Clone, Debug)]
26pub struct RemlGpuEvidence {
27    pub logdet_hessian: f64,
28    pub gradient_rho: Array1<f64>,
29}
30
31pub fn evidence_derivatives_gpu(input: RemlGpuInput<'_>) -> Result<RemlGpuEvidence, String> {
32    let p = input.penalized_hessian.nrows();
33    if p != input.penalized_hessian.ncols() {
34        return Err("REML GPU Hessian must be square".to_string());
35    }
36    for (j, derivative) in input.derivative_hessians.iter().enumerate() {
37        if derivative.dim() != (p, p) {
38            return Err(format!(
39                "REML derivative Hessian {j} has shape {:?}, expected {p}x{p}",
40                derivative.dim()
41            ));
42        }
43    }
44
45    #[cfg(target_os = "linux")]
46    {
47        if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
48            return linux_cuda::evidence_derivatives(input);
49        }
50    }
51
52    cpu_fallback::evidence_derivatives(input)
53}
54
55#[cfg(target_os = "linux")]
56mod linux_cuda {
57    use super::{RemlGpuEvidence, RemlGpuInput};
58    use gam_gpu::driver::to_col_major;
59    use gam_gpu::solver::{
60        cholesky_logdet_from_col_major, context_and_stream, pinned_htod, potrf_in_place,
61        potrs_in_place,
62    };
63    use cudarc::cusolver::DnHandle;
64    use ndarray::Array1;
65
66    pub(super) fn evidence_derivatives(input: RemlGpuInput<'_>) -> Result<RemlGpuEvidence, String> {
67        let p = input.penalized_hessian.nrows();
68        let d = input.derivative_hessians.len();
69        let (_, stream) = context_and_stream()?;
70        let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
71
72        // Upload H once and factor in-place.
73        let h_col = to_col_major(&input.penalized_hessian);
74        let mut h_dev = pinned_htod(&stream, &h_col)?;
75        potrf_in_place(&solver, &stream, p, &mut h_dev)?;
76        let factor_col = stream
77            .clone_dtoh(&h_dev)
78            .map_err(|e| format!("download Cholesky factor: {e}"))?;
79        let logdet_hessian = cholesky_logdet_from_col_major(&factor_col, p);
80
81        if d == 0 {
82            return Ok(RemlGpuEvidence {
83                logdet_hessian,
84                gradient_rho: Array1::<f64>::zeros(0),
85            });
86        }
87
88        // Stack all derivative Hessians column-wise into ONE rhs of width d*p
89        // and solve with a single batched potrs against the cached factor.
90        let total_cols = p
91            .checked_mul(d)
92            .ok_or_else(|| format!("REML GPU RHS width overflow: p={p}, d={d}"))?;
93        let total_elems = p
94            .checked_mul(total_cols)
95            .ok_or_else(|| format!("REML GPU RHS size overflow: p={p}, cols={total_cols}"))?;
96        let mut rhs_col = Vec::<f64>::with_capacity(total_elems);
97        for derivative in &input.derivative_hessians {
98            let col = to_col_major(derivative);
99            rhs_col.extend_from_slice(&col);
100        }
101        let mut rhs_dev = pinned_htod(&stream, &rhs_col)?;
102        potrs_in_place(&solver, &stream, p, total_cols, &h_dev, &mut rhs_dev)?;
103        let solved_col = stream
104            .clone_dtoh(&rhs_dev)
105            .map_err(|e| format!("download REML derivative solves: {e}"))?;
106
107        let mut gradient_rho = Array1::<f64>::zeros(d);
108        for j in 0..d {
109            let offset = j * p * p;
110            // Diagonal of H^{-1} A_j is the diagonal of the j-th p*p slab.
111            let mut trace = 0.0_f64;
112            for i in 0..p {
113                trace += solved_col[offset + i * p + i];
114            }
115            gradient_rho[j] = 0.5 * trace;
116        }
117
118        Ok(RemlGpuEvidence {
119            logdet_hessian,
120            gradient_rho,
121        })
122    }
123}
124
125mod cpu_fallback {
126    use super::{RemlGpuEvidence, RemlGpuInput};
127    use ndarray::{Array1, Array2};
128
129    pub(super) fn evidence_derivatives(input: RemlGpuInput<'_>) -> Result<RemlGpuEvidence, String> {
130        let p = input.penalized_hessian.nrows();
131        let mut identity = Array2::<f64>::zeros((p, p));
132        for i in 0..p {
133            identity[[i, i]] = 1.0;
134        }
135        let (_, logdet_hessian) = crate::gpu::pirls_gpu::cholesky_solve_gpu(
136            input.penalized_hessian,
137            identity.view(),
138        )?;
139        let mut gradient_rho = Array1::<f64>::zeros(input.derivative_hessians.len());
140        for (j, derivative) in input.derivative_hessians.iter().enumerate() {
141            let (solved, _) = crate::gpu::pirls_gpu::cholesky_solve_gpu(
142                input.penalized_hessian,
143                derivative.view(),
144            )?;
145            let mut trace = 0.0_f64;
146            for i in 0..p {
147                trace += solved[[i, i]];
148            }
149            gradient_rho[j] = 0.5 * trace;
150        }
151        Ok(RemlGpuEvidence {
152            logdet_hessian,
153            gradient_rho,
154        })
155    }
156}