gam_solve/gpu/
reml_gpu.rs1use ndarray::{Array1, ArrayView2};
18
/// Borrowed inputs for the REML evidence-derivative computation.
///
/// `evidence_derivatives_gpu` rejects the input unless `penalized_hessian` is
/// square (`p x p`) and every entry of `derivative_hessians` has that same
/// `p x p` shape.
#[derive(Clone, Debug)]
pub struct RemlGpuInput<'a> {
    /// The matrix that is Cholesky-factorized and whose log-determinant is
    /// reported. NOTE(review): presumably the penalized Hessian of the fit
    /// objective, as the name says — confirm against the caller.
    pub penalized_hessian: ArrayView2<'a, f64>,
    /// One `p x p` matrix per gradient component; entry `j` is used as the
    /// right-hand side of a solve against `penalized_hessian`.
    /// NOTE(review): presumably the derivative of the Hessian with respect to
    /// the `j`-th log smoothing parameter — confirm against the caller.
    pub derivative_hessians: Vec<ArrayView2<'a, f64>>,
}
24
/// Result of the REML evidence-derivative computation.
#[derive(Clone, Debug)]
pub struct RemlGpuEvidence {
    /// Log-determinant of the penalized Hessian as reported by the Cholesky
    /// routines this file calls.
    pub logdet_hessian: f64,
    /// One entry per derivative Hessian: half the trace of the solve of the
    /// penalized Hessian against that derivative matrix.
    pub gradient_rho: Array1<f64>,
}
30
31pub fn evidence_derivatives_gpu(input: RemlGpuInput<'_>) -> Result<RemlGpuEvidence, String> {
32 let p = input.penalized_hessian.nrows();
33 if p != input.penalized_hessian.ncols() {
34 return Err("REML GPU Hessian must be square".to_string());
35 }
36 for (j, derivative) in input.derivative_hessians.iter().enumerate() {
37 if derivative.dim() != (p, p) {
38 return Err(format!(
39 "REML derivative Hessian {j} has shape {:?}, expected {p}x{p}",
40 derivative.dim()
41 ));
42 }
43 }
44
45 #[cfg(target_os = "linux")]
46 {
47 if gam_gpu::device_runtime::GpuRuntime::global().is_some() {
48 return linux_cuda::evidence_derivatives(input);
49 }
50 }
51
52 cpu_fallback::evidence_derivatives(input)
53}
54
55#[cfg(target_os = "linux")]
56mod linux_cuda {
57 use super::{RemlGpuEvidence, RemlGpuInput};
58 use gam_gpu::driver::to_col_major;
59 use gam_gpu::solver::{
60 cholesky_logdet_from_col_major, context_and_stream, pinned_htod, potrf_in_place,
61 potrs_in_place,
62 };
63 use cudarc::cusolver::DnHandle;
64 use ndarray::Array1;
65
66 pub(super) fn evidence_derivatives(input: RemlGpuInput<'_>) -> Result<RemlGpuEvidence, String> {
67 let p = input.penalized_hessian.nrows();
68 let d = input.derivative_hessians.len();
69 let (_, stream) = context_and_stream()?;
70 let solver = DnHandle::new(stream.clone()).map_err(|e| format!("cusolver init: {e}"))?;
71
72 let h_col = to_col_major(&input.penalized_hessian);
74 let mut h_dev = pinned_htod(&stream, &h_col)?;
75 potrf_in_place(&solver, &stream, p, &mut h_dev)?;
76 let factor_col = stream
77 .clone_dtoh(&h_dev)
78 .map_err(|e| format!("download Cholesky factor: {e}"))?;
79 let logdet_hessian = cholesky_logdet_from_col_major(&factor_col, p);
80
81 if d == 0 {
82 return Ok(RemlGpuEvidence {
83 logdet_hessian,
84 gradient_rho: Array1::<f64>::zeros(0),
85 });
86 }
87
88 let total_cols = p
91 .checked_mul(d)
92 .ok_or_else(|| format!("REML GPU RHS width overflow: p={p}, d={d}"))?;
93 let total_elems = p
94 .checked_mul(total_cols)
95 .ok_or_else(|| format!("REML GPU RHS size overflow: p={p}, cols={total_cols}"))?;
96 let mut rhs_col = Vec::<f64>::with_capacity(total_elems);
97 for derivative in &input.derivative_hessians {
98 let col = to_col_major(derivative);
99 rhs_col.extend_from_slice(&col);
100 }
101 let mut rhs_dev = pinned_htod(&stream, &rhs_col)?;
102 potrs_in_place(&solver, &stream, p, total_cols, &h_dev, &mut rhs_dev)?;
103 let solved_col = stream
104 .clone_dtoh(&rhs_dev)
105 .map_err(|e| format!("download REML derivative solves: {e}"))?;
106
107 let mut gradient_rho = Array1::<f64>::zeros(d);
108 for j in 0..d {
109 let offset = j * p * p;
110 let mut trace = 0.0_f64;
112 for i in 0..p {
113 trace += solved_col[offset + i * p + i];
114 }
115 gradient_rho[j] = 0.5 * trace;
116 }
117
118 Ok(RemlGpuEvidence {
119 logdet_hessian,
120 gradient_rho,
121 })
122 }
123}
124
125mod cpu_fallback {
126 use super::{RemlGpuEvidence, RemlGpuInput};
127 use ndarray::{Array1, Array2};
128
129 pub(super) fn evidence_derivatives(input: RemlGpuInput<'_>) -> Result<RemlGpuEvidence, String> {
130 let p = input.penalized_hessian.nrows();
131 let mut identity = Array2::<f64>::zeros((p, p));
132 for i in 0..p {
133 identity[[i, i]] = 1.0;
134 }
135 let (_, logdet_hessian) = crate::gpu::pirls_gpu::cholesky_solve_gpu(
136 input.penalized_hessian,
137 identity.view(),
138 )?;
139 let mut gradient_rho = Array1::<f64>::zeros(input.derivative_hessians.len());
140 for (j, derivative) in input.derivative_hessians.iter().enumerate() {
141 let (solved, _) = crate::gpu::pirls_gpu::cholesky_solve_gpu(
142 input.penalized_hessian,
143 derivative.view(),
144 )?;
145 let mut trace = 0.0_f64;
146 for i in 0..p {
147 trace += solved[[i, i]];
148 }
149 gradient_rho[j] = 0.5 * trace;
150 }
151 Ok(RemlGpuEvidence {
152 logdet_hessian,
153 gradient_rho,
154 })
155 }
156}