use serde::{Deserialize, Serialize};
/// Mixed-precision mode for GPU linear solves.
///
/// The dispatch code in this file consults it only through
/// `GpuDispatchPolicy::iterative_refinement_should_attempt`: `Refinement` enables iterative
/// refinement (subject to a minimum dimension), while `Off` and `Never` both disable it.
/// NOTE(review): `Off` and `Never` behave identically in the visible code; any difference in
/// intent is defined elsewhere — confirm before merging them.
///
/// The enum is serialized through serde. Index-based formats encode the variant position, so
/// keep the existing order and append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GpuMixedPrecisionPolicy {
    /// Refinement disabled.
    Off,
    /// Attempt iterative refinement once the problem dimension is large enough.
    Refinement,
    /// Refinement disabled.
    Never,
}
/// Size and cost thresholds that decide when linear-algebra work is routed to the GPU.
///
/// The predicates on `GpuDispatchPolicy` in this file read only `xtwx_flops_min`,
/// `gemm_min_flops`, `potrf_min_p` and `mixed_precision`. The other thresholds are presumably
/// consulted by callers elsewhere, and their descriptions below are inferred from their names
/// (TODO confirm). Field order is kept stable because the struct derives serde
/// `Serialize`/`Deserialize`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct GpuDispatchPolicy {
    /// Row-count threshold tied to the `xtwx` reduction (not read in this file).
    pub xtwx_n_min: usize,
    /// Flop threshold for dense reductions; the effective bar is the smaller of this and
    /// `gemm_min_flops`.
    pub xtwx_flops_min: usize,
    /// `p` below which a fused `xtwx` kernel is presumably preferred (not read in this file).
    pub xtwx_use_fused_below_p: usize,
    /// Flop threshold for dense GEMV offload; also a candidate for the dense-reduction bar.
    pub gemm_min_flops: usize,
    /// Minimum `p` for GPU `potrf`.
    pub potrf_min_p: usize,
    /// Largest `p` for the batched small-dense `potrf` path (not read in this file).
    pub small_dense_batched_potrf_max_p: usize,
    /// Minimum batch for the batched small-dense `potrf` path (not read in this file).
    pub small_dense_batched_potrf_min_batch: usize,
    /// Minimum `p` for GPU `syevd` (not read in this file).
    pub syevd_min_p: usize,
    /// Minimum non-zero count for sparse GPU work (not read in this file).
    pub sparse_min_nnz: usize,
    /// Minimum `n` for fused kernels (not read in this file).
    pub fused_kernel_min_n: usize,
    /// Minimum design size, in bytes, for keeping it device-resident (not read here).
    pub keep_design_resident_min_bytes: usize,
    /// Minimum `p` at which GPU factorization is preferred (not read in this file).
    pub prefer_gpu_factorization_min_p: usize,
    /// Minimum `n` for row-wise kernels (not read in this file).
    pub row_kernel_min_n: usize,
    /// Mixed-precision / iterative-refinement mode.
    pub mixed_precision: GpuMixedPrecisionPolicy,
}
impl Default for GpuDispatchPolicy {
fn default() -> Self {
Self {
xtwx_n_min: 50_000,
xtwx_flops_min: 100_000_000,
xtwx_use_fused_below_p: 256,
gemm_min_flops: 100_000_000,
potrf_min_p: 512,
small_dense_batched_potrf_max_p: 32,
small_dense_batched_potrf_min_batch: 8,
syevd_min_p: 256,
sparse_min_nnz: 1_000_000,
fused_kernel_min_n: 100_000,
keep_design_resident_min_bytes: 32 * 1024 * 1024,
prefer_gpu_factorization_min_p: 512,
row_kernel_min_n: 50_000,
mixed_precision: GpuMixedPrecisionPolicy::Refinement,
}
}
}
impl GpuDispatchPolicy {
    /// Smallest dimension `p` for which [`Self::iterative_refinement_should_attempt`] can
    /// return `true`.
    pub const REFINEMENT_MIN_P: usize = 64;
    /// Cap on refinement steps. Not read inside this impl; presumably consumed by callers'
    /// refinement loops (TODO confirm).
    pub const REFINEMENT_MAX_STEPS: usize = 3;
    /// Refinement convergence tolerance. Not read inside this impl; presumably consumed by
    /// callers (TODO confirm).
    pub const REFINEMENT_TOL: f64 = 1e-12;

    /// Returns `true` only when `mixed_precision` is `Refinement` and
    /// `p >= REFINEMENT_MIN_P`. `Off` and `Never` always yield `false`.
    #[inline]
    pub const fn iterative_refinement_should_attempt(&self, p: usize) -> bool {
        match self.mixed_precision {
            GpuMixedPrecisionPolicy::Off | GpuMixedPrecisionPolicy::Never => false,
            GpuMixedPrecisionPolicy::Refinement => p >= Self::REFINEMENT_MIN_P,
        }
    }

    /// Decides whether a dense matrix-vector product over an `n x p` operand targets the GPU.
    ///
    /// Always `true` when `resident` is set (presumably: the operand already lives on the
    /// device). Otherwise the estimate `2·n·p`, saturating at `usize::MAX`, must reach
    /// `gemm_min_flops`.
    pub const fn dense_gemv_target_is_gpu(&self, n: usize, p: usize, resident: bool) -> bool {
        resident || n.saturating_mul(p).saturating_mul(2) >= self.gemm_min_flops
    }

    /// Decides whether the `xtwx` reduction (presumably XᵀWX) of an `n x p` design targets
    /// the GPU.
    ///
    /// Requires a `materialized` design, non-zero `n` and `p`, and an estimated `2·n·p²`
    /// flops at or above the dense-reduction threshold (the smaller of `xtwx_flops_min` and
    /// `gemm_min_flops`).
    pub const fn xtwx_target_is_gpu(&self, n: usize, p: usize, materialized: bool) -> bool {
        materialized && n > 0 && p > 0 && self.xtwx_flops(n, p) >= self.dense_reduction_flops_min()
    }

    /// Decides whether the `xtwy` reduction (presumably XᵀWY) of an `n x px` design against
    /// `q` right-hand columns targets the GPU.
    ///
    /// Requires a `materialized` design, non-zero `n`, `px` and `q`, and an estimated
    /// `2·n·px·q` flops at or above the dense-reduction threshold.
    pub const fn xtwy_target_is_gpu(
        &self,
        n: usize,
        px: usize,
        q: usize,
        materialized: bool,
    ) -> bool {
        materialized
            && n > 0
            && px > 0
            && q > 0
            && self.xtwy_flops(n, px, q) >= self.dense_reduction_flops_min()
    }

    /// Decides whether `potrf` (LAPACK's Cholesky routine name) on a `p x p` matrix targets
    /// the GPU. Requires `h_resident` and `p >= potrf_min_p`.
    pub const fn potrf_target_is_gpu(&self, p: usize, h_resident: bool) -> bool {
        h_resident && p >= self.potrf_min_p
    }

    /// Decides whether dense Hessian-sized work for an `n x p` problem justifies a device loop.
    ///
    /// Requires `n > 0`, `p >= DEVICE_LOOP_MIN_P`, and `2·n·p²` flops at or above the
    /// dense-reduction threshold.
    pub const fn dense_hessian_work_target_is_gpu(&self, n: usize, p: usize) -> bool {
        n > 0
            && p >= Self::DEVICE_LOOP_MIN_P
            && self.xtwx_flops(n, p) >= self.dense_reduction_flops_min()
    }

    /// Effective flop threshold for dense reductions: the smaller of `xtwx_flops_min` and
    /// `gemm_min_flops`, widened to `u128`.
    const fn dense_reduction_flops_min(&self) -> u128 {
        // `Ord::min` cannot be called in a `const fn`, hence the explicit comparison.
        if self.xtwx_flops_min < self.gemm_min_flops {
            self.xtwx_flops_min as u128
        } else {
            self.gemm_min_flops as u128
        }
    }

    /// Flop estimate `2·n·p²` for the `xtwx` reduction.
    ///
    /// Saturates at `u128::MAX`. With `usize` operands the exact product can exceed `u128`
    /// (up to about 2^193 on 64-bit targets). Unchecked multiplication would panic in debug
    /// builds and wrap in release builds, which could flip the dispatch decision.
    const fn xtwx_flops(&self, n: usize, p: usize) -> u128 {
        let p = p as u128;
        2u128.saturating_mul(n as u128).saturating_mul(p).saturating_mul(p)
    }

    /// Flop estimate `2·n·px·q` for the `xtwy` reduction, saturating like `xtwx_flops`.
    const fn xtwy_flops(&self, n: usize, px: usize, q: usize) -> u128 {
        2u128
            .saturating_mul(n as u128)
            .saturating_mul(px as u128)
            .saturating_mul(q as u128)
    }

    /// Minimum total work (per-apply estimate times CG iterations) for
    /// [`Self::reduced_schur_matvec_should_offload`] to offload.
    pub const MATVEC_OFFLOAD_FLOPS_MIN: u128 = 10_000_000;
    /// Not read by `reduced_schur_matvec_should_offload`, which accepts any `cg_iters >= 1`.
    /// Its role is defined by callers (TODO confirm).
    pub const MATVEC_OFFLOAD_MIN_CG_ITERS: usize = 8;

    /// Conservative per-apply work estimate `n·(2·d·k + d²)` for the reduced Schur matvec.
    ///
    /// This is deliberately an undercount: the tests in this file require it to be strictly
    /// smaller than `n·(4·d·k + d²)`. All arithmetic saturates.
    const fn admission_work_lower_bound(n: usize, k: usize, d: usize) -> u128 {
        let (n, k, d) = (n as u128, k as u128, d as u128);
        let per_row = 2u128
            .saturating_mul(d)
            .saturating_mul(k)
            .saturating_add(d.saturating_mul(d));
        n.saturating_mul(per_row)
    }

    /// Decides whether the reduced-Schur matrix-vector product should be offloaded to the GPU.
    ///
    /// Returns `false` when any of `n`, `k`, `d`, `cg_iters` is zero, or when
    /// `k < DEVICE_LOOP_MIN_P`. Otherwise the per-apply lower bound from
    /// `admission_work_lower_bound`, multiplied (saturating) by `cg_iters`, must reach
    /// [`Self::MATVEC_OFFLOAD_FLOPS_MIN`]. Only `k` is width-gated; a narrow `d` can still
    /// qualify.
    pub const fn reduced_schur_matvec_should_offload(
        &self,
        n: usize,
        k: usize,
        d: usize,
        cg_iters: usize,
    ) -> bool {
        if n == 0 || k == 0 || d == 0 || cg_iters == 0 {
            return false;
        }
        if k < Self::DEVICE_LOOP_MIN_P {
            return false;
        }
        let per_apply = Self::admission_work_lower_bound(n, k, d);
        let total = per_apply.saturating_mul(cg_iters as u128);
        total >= Self::MATVEC_OFFLOAD_FLOPS_MIN
    }
}
/// Default GPU throughput bar, in rows per second.
///
/// Used by `GpuThroughputVerdict::from_measurement` and
/// `EncodeDeploymentDecision::from_device_measurement`.
pub const GPU_THROUGHPUT_TARGET_ROWS_PER_SEC: f64 = 1.0e5;
/// Outcome of comparing a measured GPU row throughput against a target rate.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GpuThroughputVerdict {
    /// The raw measurement, recorded as given (it may be non-finite or non-positive).
    pub measured_rows_per_sec: f64,
    /// The target the measurement was compared against.
    pub target_rows_per_sec: f64,
    /// `measured / target` when a comparison was possible, otherwise `0.0`.
    pub fraction_of_target: f64,
    /// Whether the measurement reached the target.
    pub meets_target: bool,
}
impl GpuThroughputVerdict {
    /// Grades `measured_rows_per_sec` against [`GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`].
    #[inline]
    pub fn from_measurement(measured_rows_per_sec: f64) -> Self {
        Self::from_measurement_against(measured_rows_per_sec, GPU_THROUGHPUT_TARGET_ROWS_PER_SEC)
    }

    /// Grades `measured_rows_per_sec` against an explicit `target_rows_per_sec`.
    ///
    /// Both inputs are stored verbatim. A comparison is made only when the measurement and
    /// the target are each finite and strictly positive. Otherwise `fraction_of_target` is
    /// `0.0` and `meets_target` is `false`. Gating both derived fields on the same condition
    /// keeps them consistent: `meets_target` is `true` exactly when
    /// `fraction_of_target >= 1.0`.
    #[inline]
    pub fn from_measurement_against(measured_rows_per_sec: f64, target_rows_per_sec: f64) -> Self {
        let usable_measurement =
            measured_rows_per_sec.is_finite() && measured_rows_per_sec > 0.0;
        // A zero, negative or non-finite target cannot meaningfully be "met"; without this
        // gate a non-positive target reported `meets_target` alongside a 0.0 fraction.
        let usable_target = target_rows_per_sec.is_finite() && target_rows_per_sec > 0.0;
        let comparable = usable_measurement && usable_target;
        let fraction_of_target = if comparable {
            measured_rows_per_sec / target_rows_per_sec
        } else {
            0.0
        };
        Self {
            measured_rows_per_sec,
            target_rows_per_sec,
            fraction_of_target,
            meets_target: comparable && measured_rows_per_sec >= target_rows_per_sec,
        }
    }
}
/// Why an encode deployment decision could not be reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncodeDecisionBlocked {
    /// No device to measure on (the tests use it for a CPU-only run).
    NoDevice,
    /// A device exists but no device encode kernel is available (presumably; TODO confirm).
    NoDeviceEncodeKernel,
    /// The device path was not engaged. `from_device_measurement_against` also uses it when
    /// the device was engaged but produced no usable rate.
    DeviceNotEngaged,
}
/// Verdict on whether a device encode path meets its throughput target.
///
/// `Met` and `Unmet` carry the measurement and target that produced them; `Undetermined`
/// records why no verdict was possible.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EncodeDeploymentDecision {
    /// The measured rate reached the target.
    Met {
        /// Measured device throughput.
        measured_rows_per_sec: f64,
        /// Target it was compared against.
        target_rows_per_sec: f64,
    },
    /// The measured rate fell short of the target.
    Unmet {
        /// Measured device throughput.
        measured_rows_per_sec: f64,
        /// Target it was compared against.
        target_rows_per_sec: f64,
    },
    /// No verdict could be reached.
    Undetermined {
        /// What blocked the decision.
        reason: EncodeDecisionBlocked,
    },
}
impl EncodeDeploymentDecision {
    /// Classifies a device encode measurement against [`GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`].
    ///
    /// Shorthand for [`Self::from_device_measurement_against`] with the default target.
    #[must_use]
    pub fn from_device_measurement(engaged: bool, measured_rows_per_sec: f64) -> Self {
        let default_target = GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;
        Self::from_device_measurement_against(engaged, measured_rows_per_sec, default_target)
    }

    /// Classifies a device encode measurement against an explicit target.
    ///
    /// Yields `Undetermined { reason: DeviceNotEngaged }` unless the device path was
    /// `engaged` and the rate is finite and strictly positive. Otherwise the rate is compared
    /// with `>=`: reaching the target exactly gives `Met`, and anything below gives `Unmet`.
    ///
    /// NOTE(review): the target itself is not validated. A NaN target compares false and so
    /// yields `Unmet`, and a non-positive target makes every usable rate `Met`.
    #[must_use]
    pub fn from_device_measurement_against(
        engaged: bool,
        measured_rows_per_sec: f64,
        target_rows_per_sec: f64,
    ) -> Self {
        let rate_is_trustworthy =
            engaged && measured_rows_per_sec.is_finite() && measured_rows_per_sec > 0.0;
        if !rate_is_trustworthy {
            return Self::blocked(EncodeDecisionBlocked::DeviceNotEngaged);
        }
        let reached = measured_rows_per_sec >= target_rows_per_sec;
        if reached {
            Self::Met {
                measured_rows_per_sec,
                target_rows_per_sec,
            }
        } else {
            Self::Unmet {
                measured_rows_per_sec,
                target_rows_per_sec,
            }
        }
    }

    /// Builds an `Undetermined` decision carrying `reason`.
    #[must_use]
    pub fn blocked(reason: EncodeDecisionBlocked) -> Self {
        Self::Undetermined { reason }
    }

    /// `true` only for `Met`: the measured device rate already reaches the target.
    #[must_use]
    pub fn surrogate_unneeded(&self) -> bool {
        matches!(*self, Self::Met { .. })
    }

    /// `true` only for `Unmet`: a measured shortfall against the target.
    #[must_use]
    pub fn surrogate_justified(&self) -> bool {
        matches!(*self, Self::Unmet { .. })
    }

    /// `true` only for `Undetermined`, which answers neither surrogate question.
    #[must_use]
    pub fn is_undetermined(&self) -> bool {
        !matches!(*self, Self::Met { .. } | Self::Unmet { .. })
    }
}
/// Response family / link pairing for a PIRLS loop.
///
/// The admission predicates in this file only check whether a family is present, not which
/// one it is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PirlsLoopFamilyKind {
    /// Bernoulli response, logit link.
    BernoulliLogit,
    /// Bernoulli response, probit link.
    BernoulliProbit,
    /// Bernoulli response, complementary log-log link.
    BernoulliCLogLog,
    /// Poisson response, log link.
    PoissonLog,
    /// Gaussian response, identity link.
    GaussianIdentity,
    /// Gamma response, log link.
    GammaLog,
}
/// Which curvature a PIRLS loop uses (presumably expected-Fisher vs observed information).
///
/// Not consulted by the admission predicates in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PirlsLoopCurvatureKind {
    /// Fisher (expected) curvature.
    Fisher,
    /// Observed curvature.
    Observed,
}
/// Inputs to `GpuDispatchPolicy::should_run_reml_outer_on_device`.
#[derive(Debug, Clone, Copy)]
pub struct RemlOuterAdmission {
    /// Row count (presumably observations) fed to the dense-Hessian work check.
    pub n: usize,
    /// Column / coefficient count fed to the dense-Hessian work check.
    pub p: usize,
    /// Number of `rho` parameters; the device path requires at least two.
    pub num_rho: usize,
    /// Response family; `None` keeps the outer loop off the device.
    pub family: Option<PirlsLoopFamilyKind>,
    /// Curvature kind; not consulted by the admission predicate in this file.
    pub curvature: PirlsLoopCurvatureKind,
    /// Whether a GPU is available; `false` keeps the outer loop off the device.
    pub gpu_available: bool,
}
/// Inputs to `GpuDispatchPolicy::should_use_gpu_pirls_loop`.
#[derive(Debug, Clone, Copy)]
pub struct PirlsLoopAdmission {
    /// Row count (presumably observations) fed to the dense-Hessian work check.
    pub n: usize,
    /// Column / coefficient count fed to the dense-Hessian work check.
    pub p: usize,
    /// Response family; `None` keeps the loop off the GPU.
    pub family: Option<PirlsLoopFamilyKind>,
    /// Curvature kind; not consulted by the admission predicate in this file.
    pub curvature: PirlsLoopCurvatureKind,
    /// Whether a GPU is available; `false` keeps the loop off the GPU.
    pub gpu_available: bool,
}
impl GpuDispatchPolicy {
    /// Minimum width for device-side loops.
    ///
    /// It gates `p` in `dense_hessian_work_target_is_gpu` and `k` in
    /// `reduced_schur_matvec_should_offload`.
    pub const DEVICE_LOOP_MIN_P: usize = 32;

    /// Decides whether a whole PIRLS loop runs on the GPU.
    ///
    /// All of the following must hold:
    /// - a GPU is available;
    /// - `dense_hessian_work_target_is_gpu(n, p)` accepts the shape;
    /// - a family is specified.
    ///
    /// The curvature kind is not consulted.
    pub const fn should_use_gpu_pirls_loop(&self, adm: PirlsLoopAdmission) -> bool {
        adm.gpu_available
            && self.dense_hessian_work_target_is_gpu(adm.n, adm.p)
            && adm.family.is_some()
    }

    /// Decides whether the REML outer loop runs on the device.
    ///
    /// Uses the same gates as [`Self::should_use_gpu_pirls_loop`], plus `num_rho >= 2`
    /// (presumably the number of smoothing parameters). The curvature kind is not consulted.
    pub const fn should_run_reml_outer_on_device(&self, adm: RemlOuterAdmission) -> bool {
        adm.gpu_available
            && self.dense_hessian_work_target_is_gpu(adm.n, adm.p)
            && adm.num_rho >= 2
            && adm.family.is_some()
    }
}
#[cfg(test)]
mod refinement_policy_tests {
    //! Pins `GpuDispatchPolicy::iterative_refinement_should_attempt`: the `p` threshold under
    //! the default `Refinement` mode, and the blanket refusal under `Off` and `Never`.
    use super::*;

    /// Default policy with only the mixed-precision mode overridden.
    fn policy_with(mode: GpuMixedPrecisionPolicy) -> GpuDispatchPolicy {
        GpuDispatchPolicy {
            mixed_precision: mode,
            ..GpuDispatchPolicy::default()
        }
    }

    #[test]
    fn refinement_policy_admits_large_p() {
        let policy = GpuDispatchPolicy::default();
        for p in [512, GpuDispatchPolicy::REFINEMENT_MIN_P] {
            assert!(policy.iterative_refinement_should_attempt(p));
        }
    }

    #[test]
    fn refinement_policy_rejects_small_p() {
        let policy = GpuDispatchPolicy::default();
        for p in [GpuDispatchPolicy::REFINEMENT_MIN_P - 1, 0] {
            assert!(!policy.iterative_refinement_should_attempt(p));
        }
    }

    #[test]
    fn off_policy_never_attempts_refinement() {
        let policy = policy_with(GpuMixedPrecisionPolicy::Off);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }

    #[test]
    fn never_policy_never_attempts_refinement() {
        let policy = policy_with(GpuMixedPrecisionPolicy::Never);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }
}
#[cfg(test)]
mod reduced_schur_matvec_offload_tests {
    //! Pins the admission rule of `GpuDispatchPolicy::reduced_schur_matvec_should_offload`
    //! and the conservativeness of its per-apply work estimate.
    use super::*;

    const CG: usize = GpuDispatchPolicy::MATVEC_OFFLOAD_MIN_CG_ITERS;

    /// Runs the offload predicate under the default policy.
    fn offloads(n: usize, k: usize, d: usize, cg_iters: usize) -> bool {
        GpuDispatchPolicy::default().reduced_schur_matvec_should_offload(n, k, d, cg_iters)
    }

    #[test]
    fn admits_llm_sae_matvec_shape() {
        assert!(offloads(2_000, 2_048, 8, CG));
        // The same shape is too narrow (p = 8) for the dense-Hessian device path.
        let policy = GpuDispatchPolicy::default();
        assert!(!policy.dense_hessian_work_target_is_gpu(2_000, 8));
    }

    #[test]
    fn admits_llm_shape_with_one_cg_iter() {
        assert!(offloads(2_000, 2_048, 8, 1));
    }

    #[test]
    fn rejects_tiny_shape_where_transfer_dominates() {
        assert!(!offloads(30, 8, 2, CG));
        assert!(!offloads(300, 8, 4, 16));
    }

    #[test]
    fn rejects_narrow_border_even_with_huge_row_count() {
        let narrow = GpuDispatchPolicy::DEVICE_LOOP_MIN_P - 1;
        assert!(!offloads(1_000_000, narrow, 64, 64));
    }

    #[test]
    fn rejects_degenerate_dimensions() {
        let cases = [
            (0, 2_048, 8, 8),
            (2_000, 0, 8, 8),
            (2_000, 2_048, 0, 8),
            (2_000, 2_048, 8, 0),
        ];
        for (n, k, d, cg_iters) in cases {
            assert!(!offloads(n, k, d, cg_iters));
        }
    }

    #[test]
    fn monotone_in_cg_iters() {
        let (n, k, d) = (200usize, GpuDispatchPolicy::DEVICE_LOOP_MIN_P, 4usize);
        assert!(!offloads(n, k, d, 1));
        assert!(offloads(n, k, d, 1_000));
        assert!(offloads(n, k, d, 5_000));
    }

    #[test]
    fn admission_lower_bound_undercounts_actual_work() {
        let shapes = [
            (2_000usize, 2_048usize, 8usize),
            (200, GpuDispatchPolicy::DEVICE_LOOP_MIN_P, 4),
            (1, 1, 1),
        ];
        for (n, k, d) in shapes {
            let lower = GpuDispatchPolicy::admission_work_lower_bound(n, k, d);
            let (rows, border, width) = (n as u128, k as u128, d as u128);
            let actual = rows * (4 * width * border + width * width);
            assert!(
                lower < actual,
                "admission lower bound {lower} must undercount actual work {actual} for ({n},{k},{d})"
            );
        }
    }
}
#[cfg(test)]
mod encode_deployment_decision_tests {
    //! Pins how `EncodeDeploymentDecision` maps device measurements to verdicts, and that
    //! every blocked path stays `Undetermined`.
    use super::*;

    /// Asserts that `decision` is `Undetermined` and answers neither surrogate question.
    fn assert_no_verdict(decision: &EncodeDeploymentDecision) {
        assert!(decision.is_undetermined());
        assert!(!decision.surrogate_unneeded());
        assert!(!decision.surrogate_justified());
    }

    #[test]
    fn cpu_rate_can_never_meet_or_refute_the_target() {
        let cpu_only = EncodeDeploymentDecision::blocked(EncodeDecisionBlocked::NoDevice);
        assert_no_verdict(&cpu_only);
        // Even an enormous rate does not count when the device path was not engaged.
        let false_routed = EncodeDeploymentDecision::from_device_measurement(false, 1.0e9);
        assert!(false_routed.is_undetermined());
        assert!(!false_routed.surrogate_unneeded());
    }

    #[test]
    fn engaged_measurement_decides_by_the_number() {
        let target = GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;
        let decide = |rate: f64| EncodeDeploymentDecision::from_device_measurement(true, rate);

        let met = decide(target * 2.0);
        assert!(matches!(met, EncodeDeploymentDecision::Met { .. }));
        assert!(met.surrogate_unneeded());
        assert!(!met.surrogate_justified());
        assert!(!met.is_undetermined());

        let unmet = decide(target * 0.25);
        assert!(matches!(unmet, EncodeDeploymentDecision::Unmet { .. }));
        assert!(unmet.surrogate_justified());
        assert!(!unmet.surrogate_unneeded());

        // Hitting the target exactly counts as meeting it.
        assert!(decide(target).surrogate_unneeded());
    }

    #[test]
    fn engaged_but_non_usable_rate_is_undetermined_not_a_pass() {
        for bad in [0.0, -1.0, f64::NAN, f64::INFINITY] {
            let d = EncodeDeploymentDecision::from_device_measurement(true, bad);
            assert!(
                d.is_undetermined(),
                "an engaged-but-unusable rate {bad} must be Undetermined, not a decision"
            );
            assert!(!d.surrogate_unneeded());
            assert!(!d.surrogate_justified());
        }
    }

    #[test]
    fn blocked_reasons_are_all_undetermined() {
        for reason in [
            EncodeDecisionBlocked::NoDevice,
            EncodeDecisionBlocked::NoDeviceEncodeKernel,
            EncodeDecisionBlocked::DeviceNotEngaged,
        ] {
            assert_no_verdict(&EncodeDeploymentDecision::blocked(reason));
        }
    }
}