1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
/// Mixed-precision mode stored in `GpuDispatchPolicy::mixed_precision`.
///
/// `GpuDispatchPolicy::default()` selects [`MixedPrecisionPolicy::Off`].
///
/// NOTE(review): what each variant permits is not defined at this
/// declaration; in particular, confirm how `Off` and `Never` differ.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MixedPrecisionPolicy {
/// Mode used by the default dispatch policy.
Off,
/// Presumably allows reduced precision for screening passes only — TODO confirm.
Screening,
/// NOTE(review): looks like a hard opt-out, as opposed to `Off` — confirm.
Never,
}
/// Size thresholds that decide whether an operation targets the GPU.
///
/// The predicates on this type (`dense_gemv_target_is_gpu`,
/// `xtwx_target_is_gpu`, `potrf_target_is_gpu`) read `gemm_min_flops`,
/// `xtwx_n_min`, `xtwx_flops_min` and `potrf_min_p`.
///
/// NOTE(review): the remaining fields are not read by those three
/// predicates — confirm their consumers.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct GpuDispatchPolicy {
/// Minimum row count `n` for `xtwx_target_is_gpu` to pick the GPU
/// (checked as `n >= xtwx_n_min`).
pub xtwx_n_min: usize,
/// Minimum flop estimate for `xtwx_target_is_gpu`, compared against the
/// saturating product `2·n·p·p`.
pub xtwx_flops_min: usize,
/// Column-count cutoff `p` for the fused XtWX path.
/// NOTE(review): not read by `xtwx_target_is_gpu`; presumably selects the
/// fused kernel when `p` is below this value — confirm.
pub xtwx_use_fused_below_p: usize,
/// Minimum flop estimate for `dense_gemv_target_is_gpu` when the operand is
/// not already resident, compared against the saturating product `2·n·p`.
pub gemm_min_flops: usize,
/// Minimum dimension `p` for `potrf_target_is_gpu` to pick the GPU
/// (checked as `p >= potrf_min_p`, and only when the matrix is resident).
pub potrf_min_p: usize,
/// Minimum dimension `p` for the SyEVD threshold.
/// NOTE(review): presumably gates the symmetric eigensolver — confirm.
pub syevd_min_p: usize,
/// Minimum non-zero count for sparse work.
/// NOTE(review): presumably gates sparse kernels — confirm.
pub sparse_min_nnz: usize,
/// Minimum row count `n` for the fused kernel.
pub fused_kernel_min_n: usize,
/// Byte-size threshold for keeping the design resident.
/// NOTE(review): presumably device residency of the design matrix — confirm.
pub keep_design_resident_min_bytes: usize,
/// Minimum dimension `p` for preferring GPU factorization.
pub prefer_gpu_factorization_min_p: usize,
/// Minimum row count `n` for the row kernel.
pub row_kernel_min_n: usize,
/// Mixed-precision mode; see `MixedPrecisionPolicy`.
pub mixed_precision: MixedPrecisionPolicy,
}
impl Default for GpuDispatchPolicy {
    /// Returns the stock auto-dispatch thresholds, tuned for biobank-scale
    /// workloads.
    ///
    /// * `gemm_min_flops = 1e8` — generic dense GEMM / GEMV is only worth a
    ///   device hop once the kernel reaches 10⁸ flops (at two flops per
    ///   multiply-add that is roughly a 370×370×370 product). Below that, the
    ///   launch + PCIe round-trip dominates.
    /// * `xtwx_n_min = 50_000`, `xtwx_flops_min = 1e8` — `Xᵀ·diag(w)·X` is
    ///   sent to the device only from `n ≥ 50k` rows and a `2·n·p²` estimate
    ///   of 10⁸ flops upward; the row threshold ensures the weight broadcast
    ///   is stream-amortized.
    /// * `xtwx_use_fused_below_p = 256` — column cutoff meant to rule out tiny
    ///   GLM-style designs that are bandwidth-bound on CPU already.
    ///   NOTE(review): `xtwx_target_is_gpu` does not read this field; confirm
    ///   where the `p` cutoff is enforced.
    /// * `fused_kernel_min_n = 100_000` — the 2×2 joint-Hessian kernel is
    ///   meant to run on device only when `n` exceeds 100k; below that the CPU
    ///   streaming pass keeps the entire working set resident in L3.
    /// * `potrf_min_p = 512`, `syevd_min_p = 256` — Cholesky / SyEVD live on
    ///   device once the design is large enough that the factorization itself
    ///   dominates.
    /// * The remaining defaults are `sparse_min_nnz = 1_000_000`,
    ///   `keep_design_resident_min_bytes` = 32 MiB,
    ///   `prefer_gpu_factorization_min_p = 512`, `row_kernel_min_n = 50_000`,
    ///   and `mixed_precision = Off`.
    fn default() -> Self {
        // Both flop-based gates share the same 10⁸ floor.
        const FLOP_FLOOR: usize = 100_000_000;
        const MIB: usize = 1024 * 1024;

        Self {
            // Flop-count thresholds.
            gemm_min_flops: FLOP_FLOOR,
            xtwx_flops_min: FLOP_FLOOR,
            // Row-count (`n`) thresholds.
            xtwx_n_min: 50_000,
            row_kernel_min_n: 50_000,
            fused_kernel_min_n: 100_000,
            // Column-count (`p`) thresholds.
            xtwx_use_fused_below_p: 256,
            syevd_min_p: 256,
            potrf_min_p: 512,
            prefer_gpu_factorization_min_p: 512,
            // Remaining size thresholds and the precision mode.
            sparse_min_nnz: 1_000_000,
            keep_design_resident_min_bytes: 32 * MIB,
            mixed_precision: MixedPrecisionPolicy::Off,
        }
    }
}
impl GpuDispatchPolicy {
    /// Decides whether a dense matrix–vector product over an `n × p` operand
    /// should target the GPU.
    ///
    /// Returns `true` unconditionally when `resident` is set. Otherwise the
    /// flop estimate `2·n·p` (saturating, so it cannot wrap) must reach
    /// `gemm_min_flops`.
    pub const fn dense_gemv_target_is_gpu(&self, n: usize, p: usize, resident: bool) -> bool {
        if resident {
            return true;
        }
        let flop_estimate = n.saturating_mul(p).saturating_mul(2);
        flop_estimate >= self.gemm_min_flops
    }

    /// Decides whether `Xᵀ·diag(w)·X` for an `n × p` design should target the
    /// GPU.
    ///
    /// All three conditions must hold: `materialized` is set, `n` is at least
    /// `xtwx_n_min`, and the flop estimate `2·n·p²` (saturating) reaches
    /// `xtwx_flops_min`.
    pub const fn xtwx_target_is_gpu(&self, n: usize, p: usize, materialized: bool) -> bool {
        if !materialized || n < self.xtwx_n_min {
            return false;
        }
        let flop_estimate = n.saturating_mul(p).saturating_mul(p).saturating_mul(2);
        flop_estimate >= self.xtwx_flops_min
    }

    /// Decides whether a Cholesky factorization of dimension `p` should target
    /// the GPU.
    ///
    /// Requires `h_resident` to be set and `p` to be at least `potrf_min_p`.
    pub const fn potrf_target_is_gpu(&self, p: usize, h_resident: bool) -> bool {
        if !h_resident {
            return false;
        }
        p >= self.potrf_min_p
    }
}
/// Operation discriminator used by the dispatch decision API. Mirrors
/// `super::GpuOperation` at the policy layer.
///
/// NOTE(review): the variant descriptions below are inferred from the names
/// only — confirm them against `super::GpuOperation`.
#[derive(Clone, Copy, Debug)]
pub enum Operation {
/// Presumably a dense matrix–matrix product (GEMM).
Gemm,
/// Presumably a dense matrix–vector product (GEMV).
Gemv,
/// Presumably the weighted Gram product `Xᵀ·diag(w)·X`.
XtDiagX,
/// Presumably the weighted cross product `Xᵀ·diag(w)·y`.
XtDiagY,
}