1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
/// Precision strategy for the dense factorization path.
///
/// Only [`MixedPrecisionPolicy::Refinement`] can make
/// `GpuDispatchPolicy::iterative_refinement_should_attempt` return `true`;
/// `Off` and `Never` both keep the solve in fp64.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MixedPrecisionPolicy {
    /// Factorize in fp64 only; no fp32 attempt and no refinement.
    Off,
    /// Try an fp32 Cholesky factorization and then correct the solution with
    /// at most `REFINEMENT_MAX_STEPS` refinement steps driven by an fp64
    /// residual.
    ///
    /// The attempt is admitted only for `p ≥ REFINEMENT_MIN_P`, where the
    /// fp64 GEMV overhead is amortized. Callers are expected to fall back to
    /// an fp64 factorization when the fp32 POTRF fails or when the measured
    /// residual stops decreasing monotonically (the κ(A)·u ≥ 1 regime).
    Refinement,
    /// Factorize in fp64 only. Behaves like `Off`, but records that the
    /// choice was made explicitly rather than left at a default.
    Never,
}
/// Size thresholds that decide whether an operation is dispatched to the GPU.
///
/// Several fields are read by the `*_target_is_gpu` / `should_*` predicates
/// implemented on this type; the rest are consumed outside the part of the
/// file reviewed here, and their notes are marked as assumptions.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GpuDispatchPolicy {
    /// Minimum row count `n` required by `xtwx_target_is_gpu`.
    pub xtwx_n_min: usize,
    /// Minimum estimated flop count (`2·n·p²`, saturating) required by
    /// `xtwx_target_is_gpu`.
    pub xtwx_flops_min: usize,
    /// Presumably the column count below which a fused `XᵀWX` kernel is
    /// preferred — consumer not visible here, TODO confirm.
    pub xtwx_use_fused_below_p: usize,
    /// Minimum estimated flop count (`2·n·p`, saturating) for
    /// `dense_gemv_target_is_gpu` to pick the GPU when the operand is not
    /// already resident.
    pub gemm_min_flops: usize,
    /// Minimum dimension `p` required by `potrf_target_is_gpu`.
    pub potrf_min_p: usize,
    /// Presumably the largest `p` handled by the batched small-dense POTRF
    /// path — consumer not visible here, TODO confirm.
    pub small_dense_batched_potrf_max_p: usize,
    /// Presumably the smallest batch size for the batched small-dense POTRF
    /// path — consumer not visible here, TODO confirm.
    pub small_dense_batched_potrf_min_batch: usize,
    /// Presumably the minimum `p` for a device-side symmetric
    /// eigendecomposition (SyEVD) — consumer not visible here, TODO confirm.
    pub syevd_min_p: usize,
    /// Presumably the minimum non-zero count for sparse operations to run on
    /// device — consumer not visible here, TODO confirm.
    pub sparse_min_nnz: usize,
    /// Presumably the minimum row count for the fused kernel — consumer not
    /// visible here, TODO confirm.
    pub fused_kernel_min_n: usize,
    /// Presumably the design-matrix size in bytes from which the design is
    /// kept resident on device — consumer not visible here, TODO confirm.
    pub keep_design_resident_min_bytes: usize,
    /// Presumably the minimum `p` from which GPU factorization is preferred —
    /// consumer not visible here, TODO confirm.
    pub prefer_gpu_factorization_min_p: usize,
    /// Minimum row count `n` required by `should_use_gpu_pirls_loop` and
    /// `should_run_reml_outer_on_device`.
    pub row_kernel_min_n: usize,
    /// Mixed-precision mode consulted by
    /// `iterative_refinement_should_attempt`.
    pub mixed_precision: MixedPrecisionPolicy,
}
impl Default for GpuDispatchPolicy {
    /// Default auto-dispatch thresholds (described by the original author as
    /// tuned for biobank-scale workloads).
    ///
    /// How the values feed the predicates implemented on this type:
    ///
    /// * `gemm_min_flops = 100_000_000` — `dense_gemv_target_is_gpu` needs an
    ///   estimated `2·n·p ≥ 10⁸` flops unless the operand is already
    ///   resident; below that, launch + PCIe round-trip is expected to
    ///   dominate.
    /// * `xtwx_n_min = 50_000`, `xtwx_flops_min = 100_000_000` —
    ///   `xtwx_target_is_gpu` needs a materialized design with `n ≥ 50_000`
    ///   rows and an estimated `2·n·p² ≥ 10⁸` flops.
    /// * `potrf_min_p = 512` — `potrf_target_is_gpu` needs a resident Hessian
    ///   with `p ≥ 512`.
    /// * `row_kernel_min_n = 50_000` — row floor of the PIRLS-loop and
    ///   REML-outer admission predicates.
    /// * `mixed_precision = Refinement` — fp32 factorization + refinement is
    ///   attempted by default once `p ≥ REFINEMENT_MIN_P`.
    ///
    /// The remaining thresholds are consumed outside the code reviewed here.
    /// NOTE(review): earlier documentation described
    /// `xtwx_use_fused_below_p = 256` as a `p > 256` requirement of the
    /// `Xᵀ·diag(w)·X` device path and `fused_kernel_min_n = 100_000` as the
    /// row floor of the 2×2 joint-Hessian kernel; neither is enforced by the
    /// predicates on this type — confirm against their consumers.
    fn default() -> Self {
        // One mebibyte, so the residency threshold reads in natural units.
        const MIB: usize = 1024 * 1024;

        Self {
            // Generic dense GEMM / GEMV.
            gemm_min_flops: 100_000_000,

            // Xᵀ·diag(w)·X.
            xtwx_n_min: 50_000,
            xtwx_flops_min: 100_000_000,
            xtwx_use_fused_below_p: 256,

            // Cholesky / SyEVD and the precision mode used for factorization.
            potrf_min_p: 512,
            prefer_gpu_factorization_min_p: 512,
            small_dense_batched_potrf_max_p: 32,
            small_dense_batched_potrf_min_batch: 8,
            syevd_min_p: 256,
            mixed_precision: MixedPrecisionPolicy::Refinement,

            // Row kernels, sparse operations and design residency.
            row_kernel_min_n: 50_000,
            fused_kernel_min_n: 100_000,
            sparse_min_nnz: 1_000_000,
            keep_design_resident_min_bytes: 32 * MIB,
        }
    }
}
impl GpuDispatchPolicy {
    /// Smallest problem dimension `p` for which the fp32 + refinement path is
    /// attempted.
    ///
    /// The residual check costs an fp64 `p × p` GEMV (about `2p²` flops) on
    /// top of the `p³/3`-flop POTRF. Requiring the GEMV to be at least 10×
    /// cheaper than the factorization gives `p ≥ 60`; the threshold is
    /// rounded up to 64 to leave margin for the POTRF/POTRS launches. The
    /// default `potrf_min_p` (512) is far above this value, so any `p` that
    /// clears the default GPU POTRF floor also clears this one.
    pub const REFINEMENT_MIN_P: usize = 64;

    /// Upper bound on fp32-correction steps per solve.
    ///
    /// Rationale carried over from the original design note: two steps are
    /// expected to be enough for κ(A) ≤ 10⁵ at fp32 unit roundoff
    /// (u ≈ 6 × 10⁻⁸); the third step is a defensive margin.
    /// NOTE(review): the earlier comment quoted errors of 10⁻⁶ after one step
    /// and 10⁻¹² after two, which do not follow from κ·u ≈ 6 × 10⁻³ —
    /// re-derive before relying on those figures.
    pub const REFINEMENT_MAX_STEPS: usize = 3;

    /// Relative residual `‖r‖ / ‖b‖` at or below which a refined solve is
    /// treated as converged. Chosen conservatively by the original author.
    pub const REFINEMENT_TOL: f64 = 1e-12;

    /// Decide whether an fp32 factorization followed by iterative refinement
    /// is worth attempting for a problem of dimension `p`.
    ///
    /// Returns `true` only when the policy is
    /// [`MixedPrecisionPolicy::Refinement`] and `p ≥ REFINEMENT_MIN_P`;
    /// `Off` and `Never` always yield `false`. A `true` result is only an
    /// admission: the caller is still expected to fall back to an fp64
    /// factorization when the fp32 POTRF fails at runtime or the measured
    /// residual is not monotone.
    #[inline]
    pub const fn iterative_refinement_should_attempt(&self, p: usize) -> bool {
        // Exhaustive on purpose: a new variant must make an explicit choice.
        let policy_allows = match self.mixed_precision {
            MixedPrecisionPolicy::Refinement => true,
            MixedPrecisionPolicy::Off | MixedPrecisionPolicy::Never => false,
        };
        policy_allows && p >= Self::REFINEMENT_MIN_P
    }

    /// Decide whether a dense GEMV over an `n × p` operand should run on the
    /// GPU.
    ///
    /// A `resident` operand always selects the GPU. Otherwise the estimated
    /// flop count `2·n·p` (saturating at `usize::MAX`) must reach
    /// `gemm_min_flops`.
    pub const fn dense_gemv_target_is_gpu(&self, n: usize, p: usize, resident: bool) -> bool {
        if resident {
            return true;
        }
        let estimated_flops = n.saturating_mul(p).saturating_mul(2);
        estimated_flops >= self.gemm_min_flops
    }

    /// Decide whether `Xᵀ·diag(w)·X` for an `n × p` design should run on the
    /// GPU.
    ///
    /// Requires all of: `materialized` is `true`, `n ≥ xtwx_n_min`, and the
    /// estimated flop count `2·n·p²` (saturating at `usize::MAX`) reaches
    /// `xtwx_flops_min`.
    pub const fn xtwx_target_is_gpu(&self, n: usize, p: usize, materialized: bool) -> bool {
        if !materialized || n < self.xtwx_n_min {
            return false;
        }
        let estimated_flops = n.saturating_mul(p).saturating_mul(p).saturating_mul(2);
        estimated_flops >= self.xtwx_flops_min
    }

    /// Decide whether a Cholesky factorization (POTRF) of dimension `p`
    /// should run on the GPU: only when `h_resident` is `true` and
    /// `p ≥ potrf_min_p`.
    pub const fn potrf_target_is_gpu(&self, p: usize, h_resident: bool) -> bool {
        if !h_resident {
            return false;
        }
        p >= self.potrf_min_p
    }
}
/// Kind of linear-algebra operation a dispatch decision is made for.
///
/// The original note says this mirrors `super::GpuOperation` at the policy
/// layer. That type is not visible from here, so the two variant lists have
/// to be kept in sync by hand — TODO confirm they currently agree.
#[derive(Clone, Copy, Debug)]
pub enum Operation {
    /// General matrix–matrix product (BLAS "GEMM" naming).
    Gemm,
    /// General matrix–vector product (BLAS "GEMV" naming).
    Gemv,
    /// Presumably the `Xᵀ·diag(w)·X` product — confirm against the
    /// dispatcher.
    XtDiagX,
    /// Presumably the `Xᵀ·diag(w)·y` product — confirm against the
    /// dispatcher.
    XtDiagY,
}
/// `(response, link)` families that the Stage 3.3 device-resident PIRLS loop
/// can evaluate without the Level-B raw-body NVRTC path.
///
/// Per the original note this is a policy-layer mirror of
/// `PirlsRowFamily::ALL`, kept separate so the admission predicates can be
/// linked from the CPU PIRLS entry without pulling a Linux-only enum into
/// every host compilation unit. `PirlsRowFamily` is not visible from here —
/// TODO confirm the two lists agree.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PirlsLoopFamilyKind {
    /// Bernoulli response, logit link.
    BernoulliLogit,
    /// Bernoulli response, probit link.
    BernoulliProbit,
    /// Bernoulli response, complementary log-log link.
    BernoulliCLogLog,
    /// Poisson response, log link.
    PoissonLog,
    /// Gaussian response, identity link.
    GaussianIdentity,
    /// Gamma response, log link.
    GammaLog,
}
/// Curvature surface the inner PIRLS loop uses, as carried by the admission
/// inputs `PirlsLoopAdmission` and `RemlOuterAdmission`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PirlsLoopCurvatureKind {
    /// Fisher curvature.
    Fisher,
    /// Observed curvature.
    Observed,
}
/// Inputs to [`GpuDispatchPolicy::should_run_reml_outer_on_device`], the
/// admission predicate for running the *outer* REML BFGS-over-ρ loop on a
/// fully device-resident driver instead of the host orchestrator that hops
/// out on every step.
///
/// Every field is plain `Copy` data that, per the original note, the CPU REML
/// entry already holds before it touches the seed generator or the inner
/// P-IRLS loop, so the check allocates nothing and can short-circuit before
/// any device call.
#[derive(Clone, Copy, Debug)]
pub struct RemlOuterAdmission {
    /// Rows of the active (post-transform) design.
    pub n: usize,
    /// Columns of the active design, i.e. the dimension of the penalised
    /// Hessian.
    pub p: usize,
    /// How many smoothing parameters ρ the outer BFGS optimises over.
    pub num_rho: usize,
    /// Family / link pair the device-resident PIRLS loop can evaluate.
    ///
    /// `None` means the family does not map onto one of the six JIT-cached
    /// row kernels; the outer loop then stays on the host orchestrator, since
    /// the inner step would hop out anyway.
    pub family: Option<PirlsLoopFamilyKind>,
    /// Curvature surface the inner loop will use. The original note ties it
    /// to `family` via `pirls_loop_curvature_for`, which is not visible from
    /// here.
    pub curvature: PirlsLoopCurvatureKind,
    /// `true` when the CUDA runtime is initialised on this host.
    pub gpu_available: bool,
}
/// Inputs to [`GpuDispatchPolicy::should_use_gpu_pirls_loop`].
///
/// Every field is plain `Copy` data that, per the original note, the CPU
/// PIRLS entry already holds before it touches the eigendecomposition engine,
/// so the admission check allocates nothing and can short-circuit before any
/// heavy work.
#[derive(Clone, Copy, Debug)]
pub struct PirlsLoopAdmission {
    /// Rows of the active (post-transform) design matrix.
    pub n: usize,
    /// Columns of the active design, i.e. the `p` of `Xᵀ X`.
    pub p: usize,
    /// `Some(_)` when the inner family maps onto one of the six JIT-cached
    /// `PirlsRowFamily` variants. `None` marks a custom family that still
    /// needs Stage 6 Level B and has not been admitted here.
    pub family: Option<PirlsLoopFamilyKind>,
    /// Curvature surface the inner loop will use. Per the original note the
    /// GPU loop has Fisher and Observed kernels; other surfaces (e.g.
    /// expected-projection surrogates) are not admitted and cannot be
    /// expressed with this enum.
    pub curvature: PirlsLoopCurvatureKind,
    /// `true` when the CUDA runtime is initialised on this host (the
    /// original note equates this with `GpuRuntime::global().is_some()`).
    pub gpu_available: bool,
}
impl GpuDispatchPolicy {
/// Conservative admission predicate for routing
/// `fit_model_for_fixed_rho_with_adaptive_kkt` through the Stage 3.3
/// device-resident PIRLS loop instead of the CPU LM loop.
///
/// The thresholds (`n ≥ 50_000`, `p ≥ 32`) are deliberately well above
/// the matrix-size where a single PIRLS iter's `XᵀWX + Cholesky` would
/// be PCIe-bandwidth-bound. Smaller fits stay on the CPU LM loop where
/// the full `PirlsResult` surface (firth, EDF, per-row weights, …) is
/// already populated as a free side-effect of the iteration.
pub const fn should_use_gpu_pirls_loop(&self, adm: PirlsLoopAdmission) -> bool {
if !adm.gpu_available {
return false;
}
if adm.n < self.row_kernel_min_n {
return false;
}
if adm.p < 32 {
return false;
}
match adm.family {
Some(_) => true,
None => false,
}
}
/// Admission predicate for routing the outer REML BFGS-over-ρ loop onto
/// a device-resident driver that keeps the BFGS state (ρ, gradient,
/// Hessian approx) on-device and only downloads the per-step scalar
/// metrics (objective value, gradient norm, convergence flag).
///
/// The thresholds piggyback on the existing inner-PIRLS admission floor
/// (`n ≥ row_kernel_min_n`, `p ≥ 32`) because the device-resident outer
/// loop calls `pirls_loop_on_stream` per step and must not pay the host
/// hop for small fits the inner loop would have rejected anyway. The
/// `num_rho ≥ 2` floor rules out the trivial single-smoother case where
/// host orchestration is already negligible and the device BFGS state
/// (one length-`num_rho` gradient + a `num_rho × num_rho` Hessian
/// approx) collapses to a couple of scalars not worth keeping on device.
pub const fn should_run_reml_outer_on_device(&self, adm: RemlOuterAdmission) -> bool {
if !adm.gpu_available {
return false;
}
if adm.n < self.row_kernel_min_n {
return false;
}
if adm.p < 32 {
return false;
}
if adm.num_rho < 2 {
return false;
}
match adm.family {
Some(_) => true,
None => false,
}
}
}
#[cfg(test)]
mod refinement_policy_tests {
    use super::*;

    /// Default thresholds with only the mixed-precision mode replaced.
    fn policy_with(mode: MixedPrecisionPolicy) -> GpuDispatchPolicy {
        GpuDispatchPolicy {
            mixed_precision: mode,
            ..Default::default()
        }
    }

    /// The default policy is `Refinement`, so any `p` at or above the floor
    /// must be admitted.
    #[test]
    fn refinement_policy_admits_large_p() {
        let policy = GpuDispatchPolicy::default();
        for p in [512, GpuDispatchPolicy::REFINEMENT_MIN_P].iter().copied() {
            assert!(
                policy.iterative_refinement_should_attempt(p),
                "p = {} should be admitted",
                p
            );
        }
    }

    /// Anything below the floor is rejected, including the degenerate `p = 0`.
    #[test]
    fn refinement_policy_rejects_small_p() {
        let policy = GpuDispatchPolicy::default();
        for p in [GpuDispatchPolicy::REFINEMENT_MIN_P - 1, 0].iter().copied() {
            assert!(
                !policy.iterative_refinement_should_attempt(p),
                "p = {} should be rejected",
                p
            );
        }
    }

    /// `Off` disables the attempt regardless of problem size.
    #[test]
    fn off_policy_never_attempts_refinement() {
        let policy = policy_with(MixedPrecisionPolicy::Off);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }

    /// `Never` disables the attempt regardless of problem size.
    #[test]
    fn never_policy_never_attempts_refinement() {
        let policy = policy_with(MixedPrecisionPolicy::Never);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }
}