use serde::{Deserialize, Serialize};
/// Mixed-precision mode consulted by the GPU dispatch layer.
///
/// Only `Refinement` can enable iterative refinement; `Off` and `Never` behave the same
/// in `GpuDispatchPolicy::iterative_refinement_should_attempt` (both return `false`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum GpuMixedPrecisionPolicy {
    /// Refinement is never attempted.
    Off,
    /// Refinement is attempted once `p >= GpuDispatchPolicy::REFINEMENT_MIN_P`.
    Refinement,
    /// Refinement is never attempted.
    /// NOTE(review): indistinguishable from `Off` in this section; confirm the intended split.
    Never,
}
/// Size thresholds that decide which linear-algebra work is routed to the GPU.
///
/// Field order is significant for positional serializers and is kept stable.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct GpuDispatchPolicy {
    /// Row-count threshold for the `xtwx` path.
    /// NOTE(review): not consulted by the predicates visible in this file section.
    pub xtwx_n_min: usize,
    /// Flop floor for dense reductions; the smaller of this and `gemm_min_flops`
    /// is the effective floor used by the `xtwx`/`xtwy`/dense-Hessian predicates.
    pub xtwx_flops_min: usize,
    /// Presumably the width below which a fused `xtwx` kernel is preferred — TODO confirm.
    pub xtwx_use_fused_below_p: usize,
    /// Flop floor for dense mat-vec offload, also a candidate for the dense-reduction floor.
    pub gemm_min_flops: usize,
    /// Minimum `p` for a device Cholesky (`potrf`) of an already-resident matrix.
    pub potrf_min_p: usize,
    /// Presumably the largest `p` handled by batched small dense Cholesky — TODO confirm.
    pub small_dense_batched_potrf_max_p: usize,
    /// Presumably the minimum batch size for batched small dense Cholesky — TODO confirm.
    pub small_dense_batched_potrf_min_batch: usize,
    /// Presumably the minimum `p` for a device eigendecomposition (`syevd`) — TODO confirm.
    pub syevd_min_p: usize,
    /// Presumably the minimum non-zero count for sparse device work — TODO confirm.
    pub sparse_min_nnz: usize,
    /// Presumably the minimum row count for fused device kernels — TODO confirm.
    pub fused_kernel_min_n: usize,
    /// Presumably the design size (bytes) above which it stays device-resident — TODO confirm.
    pub keep_design_resident_min_bytes: usize,
    /// Presumably the minimum `p` at which device factorization is preferred — TODO confirm.
    pub prefer_gpu_factorization_min_p: usize,
    /// Presumably the minimum row count for per-row device kernels — TODO confirm.
    pub row_kernel_min_n: usize,
    /// Mixed-precision mode read by `iterative_refinement_should_attempt`.
    pub mixed_precision: GpuMixedPrecisionPolicy,
}
impl Default for GpuDispatchPolicy {
    /// Baseline thresholds with mixed-precision `Refinement` enabled.
    fn default() -> Self {
        const MILLION: usize = 1_000_000;
        const MIB: usize = 1 << 20;

        Self {
            xtwx_n_min: 50_000,
            xtwx_flops_min: 100 * MILLION,
            xtwx_use_fused_below_p: 256,
            gemm_min_flops: 100 * MILLION,
            potrf_min_p: 512,
            small_dense_batched_potrf_max_p: 32,
            small_dense_batched_potrf_min_batch: 8,
            syevd_min_p: 256,
            sparse_min_nnz: MILLION,
            fused_kernel_min_n: 100_000,
            keep_design_resident_min_bytes: 32 * MIB,
            prefer_gpu_factorization_min_p: 512,
            row_kernel_min_n: 50_000,
            mixed_precision: GpuMixedPrecisionPolicy::Refinement,
        }
    }
}
impl GpuDispatchPolicy {
    /// Smallest dimension `p` for which `iterative_refinement_should_attempt` can
    /// return `true`.
    pub const REFINEMENT_MIN_P: usize = 64;
    /// Maximum number of refinement steps.
    ///
    /// NOTE(review): not read in this section; presumably consumed by the refinement loop.
    pub const REFINEMENT_MAX_STEPS: usize = 3;
    /// Refinement convergence tolerance.
    ///
    /// NOTE(review): not read in this section; presumably consumed by the refinement loop.
    pub const REFINEMENT_TOL: f64 = 1e-12;

    /// Returns whether mixed-precision iterative refinement should be tried for a
    /// problem of dimension `p`.
    ///
    /// `Off` and `Never` always return `false`; `Refinement` returns `true` once
    /// `p >= REFINEMENT_MIN_P`.
    #[inline]
    pub const fn iterative_refinement_should_attempt(&self, p: usize) -> bool {
        match self.mixed_precision {
            GpuMixedPrecisionPolicy::Off | GpuMixedPrecisionPolicy::Never => false,
            GpuMixedPrecisionPolicy::Refinement => p >= Self::REFINEMENT_MIN_P,
        }
    }

    /// Returns whether a dense `n x p` matrix-vector product should run on the device.
    ///
    /// A `resident` operand always goes to the device; otherwise the `2 * n * p` flop
    /// estimate (saturating in `usize`) must reach `gemm_min_flops`.
    pub const fn dense_gemv_target_is_gpu(&self, n: usize, p: usize, resident: bool) -> bool {
        resident || n.saturating_mul(p).saturating_mul(2) >= self.gemm_min_flops
    }

    /// Returns whether the `xtwx` reduction for `n` rows and `p` columns should run on
    /// the device.
    ///
    /// Requires a `materialized` design, non-zero `n` and `p`, and a `2 * n * p^2` flop
    /// estimate at or above the dense-reduction floor
    /// (`min(xtwx_flops_min, gemm_min_flops)`).
    pub const fn xtwx_target_is_gpu(&self, n: usize, p: usize, materialized: bool) -> bool {
        materialized && n > 0 && p > 0 && self.xtwx_flops(n, p) >= self.dense_reduction_flops_min()
    }

    /// Returns whether the `xtwy` product (`n` rows, `px` design columns, `q` right-hand
    /// columns) should run on the device.
    ///
    /// Requires a `materialized` design, non-zero dimensions, and a `2 * n * px * q` flop
    /// estimate at or above the dense-reduction floor.
    pub const fn xtwy_target_is_gpu(
        &self,
        n: usize,
        px: usize,
        q: usize,
        materialized: bool,
    ) -> bool {
        materialized
            && n > 0
            && px > 0
            && q > 0
            && self.xtwy_flops(n, px, q) >= self.dense_reduction_flops_min()
    }

    /// Returns whether a `p x p` Cholesky (`potrf`) should run on the device: only when
    /// the matrix is already device-resident and `p >= potrf_min_p`.
    pub const fn potrf_target_is_gpu(&self, p: usize, h_resident: bool) -> bool {
        h_resident && p >= self.potrf_min_p
    }

    /// Returns whether dense Hessian work for `n` rows and `p` columns is large enough
    /// for the device: `n > 0`, `p >= DEVICE_LOOP_MIN_P`, and a `2 * n * p^2` flop
    /// estimate at or above the dense-reduction floor.
    pub const fn dense_hessian_work_target_is_gpu(&self, n: usize, p: usize) -> bool {
        n > 0
            && p >= Self::DEVICE_LOOP_MIN_P
            && self.xtwx_flops(n, p) >= self.dense_reduction_flops_min()
    }

    /// Effective flop floor for dense reductions: the smaller of `xtwx_flops_min` and
    /// `gemm_min_flops`, widened to `u128`.
    const fn dense_reduction_flops_min(&self) -> u128 {
        // `Ord::min` is not callable in a `const fn`, hence the explicit branch.
        if self.xtwx_flops_min < self.gemm_min_flops {
            self.xtwx_flops_min as u128
        } else {
            self.gemm_min_flops as u128
        }
    }

    /// Flop estimate `2 * n * p^2`, saturating at `u128::MAX`.
    ///
    /// Saturation is required: on 64-bit targets extreme inputs exceed `u128::MAX`
    /// (e.g. `n = 2^63, p = 2^32` is exactly `2^128`). Unchecked `*` would panic in
    /// debug builds and wrap in release builds (here to `0`), silently rejecting the
    /// largest workloads.
    const fn xtwx_flops(&self, n: usize, p: usize) -> u128 {
        let p = p as u128;
        2u128
            .saturating_mul(n as u128)
            .saturating_mul(p.saturating_mul(p))
    }

    /// Flop estimate `2 * n * px * q`, saturating at `u128::MAX` for the same reason as
    /// `xtwx_flops`.
    const fn xtwy_flops(&self, n: usize, px: usize, q: usize) -> u128 {
        2u128
            .saturating_mul(n as u128)
            .saturating_mul(px as u128)
            .saturating_mul(q as u128)
    }

    /// Total-work floor (flops summed over all CG iterations) for offloading the
    /// reduced-Schur mat-vec when `d > 1`.
    pub const MATVEC_OFFLOAD_FLOPS_MIN: u128 = 10_000_000;
    /// Lower total-work floor used instead of `MATVEC_OFFLOAD_FLOPS_MIN` when `d == 1`.
    pub const THIN_CURVE_MATVEC_OFFLOAD_FLOPS_MIN: u128 = 1_000_000;
    /// Reference CG iteration count for mat-vec offload sizing.
    ///
    /// NOTE(review): not read by the predicates in this section.
    pub const MATVEC_OFFLOAD_MIN_CG_ITERS: usize = 8;

    /// Conservative per-apply work estimate for the reduced-Schur mat-vec,
    /// `n * (2 * d * k + d^2)`, saturating at `u128::MAX`.
    ///
    /// It stays below the `n * (4 * d * k + d^2)` per-apply costing used when planning
    /// arrow-border solves, so admission never overstates the work it gates.
    const fn admission_work_lower_bound(n: usize, k: usize, d: usize) -> u128 {
        let n = n as u128;
        let k = k as u128;
        let d = d as u128;
        n.saturating_mul(
            2u128
                .saturating_mul(d)
                .saturating_mul(k)
                .saturating_add(d.saturating_mul(d)),
        )
    }

    /// Returns whether `cg_iters` applications of the reduced-Schur mat-vec
    /// (`n` rows, border width `k`, block width `d`) justify device offload.
    ///
    /// Any zero argument, or `k < DEVICE_LOOP_MIN_P`, rejects. Otherwise the admission
    /// lower bound times `cg_iters` (saturating) must reach
    /// `THIN_CURVE_MATVEC_OFFLOAD_FLOPS_MIN` when `d == 1`, or
    /// `MATVEC_OFFLOAD_FLOPS_MIN` otherwise.
    pub const fn reduced_schur_matvec_should_offload(
        &self,
        n: usize,
        k: usize,
        d: usize,
        cg_iters: usize,
    ) -> bool {
        if n == 0 || k == 0 || d == 0 || cg_iters == 0 {
            return false;
        }
        // A narrow border cannot keep the device busy regardless of row count.
        if k < Self::DEVICE_LOOP_MIN_P {
            return false;
        }
        let per_apply = Self::admission_work_lower_bound(n, k, d);
        let total = per_apply.saturating_mul(cg_iters as u128);
        let floor = if d == 1 {
            Self::THIN_CURVE_MATVEC_OFFLOAD_FLOPS_MIN
        } else {
            Self::MATVEC_OFFLOAD_FLOPS_MIN
        };
        total >= floor
    }
}
/// Strategy recommended by `GpuDispatchPolicy::arrow_border_solve_plan` for the border.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ArrowBorderStrategy {
    /// Costed as assembling the dense `k x k` border Schur complement plus a Cholesky.
    DenseDirect,
    /// Costed as `cg_iters` reduced-Schur operator applications.
    ReducedIterative,
}
/// Cost comparison and recommendation produced by
/// `GpuDispatchPolicy::arrow_border_solve_plan`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ArrowBorderSolvePlan {
    /// Row count passed to the planner.
    pub n: usize,
    /// Border width passed to the planner.
    pub k: usize,
    /// Per-row block width passed to the planner.
    pub d: usize,
    /// Effective CG budget: a requested `0` is reported as `1`, except for degenerate
    /// shapes (any of `n`, `k`, `d` zero), where the request is echoed unchanged.
    pub cg_iters: usize,
    /// `min(n * d, k)` (saturating product); `0` for degenerate shapes.
    pub data_fit_rank: usize,
    /// `true` when `n * d < k`.
    pub dense_border_rank_deficient: bool,
    /// Dense assembly plus Cholesky flop estimate.
    pub dense_direct_flops: u128,
    /// Flop estimate for `cg_iters` reduced-Schur applications.
    pub reduced_iterative_flops: u128,
    /// Cheaper strategy by the two estimates (ties go to `DenseDirect`).
    pub recommended: ArrowBorderStrategy,
    /// Whether the recommended strategy passes its device gate.
    pub device_favorable: bool,
}
impl GpuDispatchPolicy {
    /// Flop estimate for assembling the dense `k x k` border Schur complement:
    /// `2 * n * d * k^2`, saturating at `u128::MAX`.
    const fn dense_schur_assembly_flops(n: usize, k: usize, d: usize) -> u128 {
        let k_squared = (k as u128).saturating_mul(k as u128);
        2u128
            .saturating_mul(n as u128)
            .saturating_mul(d as u128)
            .saturating_mul(k_squared)
    }

    /// Flop estimate for a dense Cholesky of the `k x k` border: `k^3 / 3`
    /// (integer division applied after the saturating cube).
    const fn dense_border_cholesky_flops(k: usize) -> u128 {
        let side = k as u128;
        let cube = side.saturating_mul(side).saturating_mul(side);
        cube / 3
    }

    /// Flop estimate for `cg_iters` reduced-Schur applications, each costed at
    /// `n * (4 * d * k + d^2)`; every step saturates at `u128::MAX`.
    const fn reduced_iterative_flops(n: usize, k: usize, d: usize, cg_iters: usize) -> u128 {
        let (rows, border, width) = (n as u128, k as u128, d as u128);
        let coupling = 4u128.saturating_mul(width).saturating_mul(border);
        let width_sq = width.saturating_mul(width);
        let per_apply = rows.saturating_mul(coupling.saturating_add(width_sq));
        per_apply.saturating_mul(cg_iters as u128)
    }

    /// Compares dense-direct and reduced-iterative cost estimates for an arrow-border
    /// solve and reports the cheaper strategy.
    ///
    /// * `n`, `k`, `d` — row count, border width, per-row block width.
    /// * `cg_iters` — CG budget for the iterative path; `0` is costed as one application.
    ///
    /// If any of `n`, `k`, `d` is zero, the result is an all-zero `DenseDirect` plan that
    /// is not device-favorable and echoes `cg_iters` unchanged. Otherwise ties go to
    /// `DenseDirect`, and `device_favorable` applies the chosen strategy's device gate.
    pub fn arrow_border_solve_plan(
        &self,
        n: usize,
        k: usize,
        d: usize,
        cg_iters: usize,
    ) -> ArrowBorderSolvePlan {
        let degenerate = n == 0 || k == 0 || d == 0;
        if degenerate {
            return ArrowBorderSolvePlan {
                n,
                k,
                d,
                cg_iters,
                data_fit_rank: 0,
                dense_border_rank_deficient: false,
                dense_direct_flops: 0,
                reduced_iterative_flops: 0,
                recommended: ArrowBorderStrategy::DenseDirect,
                device_favorable: false,
            };
        }

        // A zero budget is still costed (and reported) as one operator application.
        let budget = cg_iters.max(1);

        let assembly_cost = Self::dense_schur_assembly_flops(n, k, d);
        let factor_cost = Self::dense_border_cholesky_flops(k);
        let direct_cost = assembly_cost.saturating_add(factor_cost);
        let iterative_cost = Self::reduced_iterative_flops(n, k, d, budget);

        let fit_rows = n.saturating_mul(d);
        let data_fit_rank = k.min(fit_rows);
        let dense_border_rank_deficient = fit_rows < k;

        let recommended = if direct_cost <= iterative_cost {
            ArrowBorderStrategy::DenseDirect
        } else {
            ArrowBorderStrategy::ReducedIterative
        };

        let device_favorable = match recommended {
            ArrowBorderStrategy::DenseDirect => {
                // Device only helps when assembly, not the factorization, dominates.
                assembly_cost >= factor_cost && direct_cost >= self.dense_reduction_flops_min()
            }
            ArrowBorderStrategy::ReducedIterative => {
                self.reduced_schur_matvec_should_offload(n, k, d, budget)
            }
        };

        ArrowBorderSolvePlan {
            n,
            k,
            d,
            cg_iters: budget,
            data_fit_rank,
            dense_border_rank_deficient,
            dense_direct_flops: direct_cost,
            reduced_iterative_flops: iterative_cost,
            recommended,
            device_favorable,
        }
    }
}
/// Default device throughput target, in rows per second.
pub const GPU_THROUGHPUT_TARGET_ROWS_PER_SEC: f64 = 100_000.0;

/// Comparison of a measured throughput against a target.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GpuThroughputVerdict {
    /// The rate as supplied, even when unusable (non-finite or non-positive).
    pub measured_rows_per_sec: f64,
    /// The target compared against.
    pub target_rows_per_sec: f64,
    /// `measured / target`, or `0.0` if the rate is unusable or the target is not positive.
    pub fraction_of_target: f64,
    /// `true` only for a usable rate at or above the target.
    pub meets_target: bool,
}

impl GpuThroughputVerdict {
    /// Judges `measured_rows_per_sec` against [`GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`].
    #[inline]
    pub fn from_measurement(measured_rows_per_sec: f64) -> Self {
        let target = GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;
        Self::from_measurement_against(measured_rows_per_sec, target)
    }

    /// Judges `measured_rows_per_sec` against an explicit target.
    ///
    /// A rate is usable only when it is finite and strictly positive.
    #[inline]
    pub fn from_measurement_against(measured_rows_per_sec: f64, target_rows_per_sec: f64) -> Self {
        let is_usable = measured_rows_per_sec > 0.0 && measured_rows_per_sec.is_finite();
        let fraction_of_target = match (is_usable, target_rows_per_sec > 0.0) {
            (true, true) => measured_rows_per_sec / target_rows_per_sec,
            _ => 0.0,
        };
        let meets_target = is_usable && measured_rows_per_sec >= target_rows_per_sec;
        Self {
            measured_rows_per_sec,
            target_rows_per_sec,
            fraction_of_target,
            meets_target,
        }
    }
}
/// Why an encode deployment decision could not be reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncodeDecisionBlocked {
    /// No device is available (e.g. a CPU-only run).
    NoDevice,
    /// Presumably: a device exists but has no encode kernel — TODO confirm.
    NoDeviceEncodeKernel,
    /// The device path was not engaged. `from_device_measurement_against` also reports
    /// this when an engaged run produced a non-finite or non-positive rate.
    DeviceNotEngaged,
}

/// Outcome of judging a device encode throughput measurement.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum EncodeDeploymentDecision {
    /// The measured rate reached the target.
    Met {
        measured_rows_per_sec: f64,
        target_rows_per_sec: f64,
    },
    /// The measured rate fell short of the target (or the comparison failed, e.g. a NaN target).
    Unmet {
        measured_rows_per_sec: f64,
        target_rows_per_sec: f64,
    },
    /// No trustworthy measurement exists.
    Undetermined {
        reason: EncodeDecisionBlocked,
    },
}

impl EncodeDeploymentDecision {
    /// Judges a measurement against [`GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`].
    #[must_use]
    pub fn from_device_measurement(engaged: bool, measured_rows_per_sec: f64) -> Self {
        let target_rows_per_sec = GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;
        Self::from_device_measurement_against(engaged, measured_rows_per_sec, target_rows_per_sec)
    }

    /// Judges a measurement against an explicit target.
    ///
    /// Returns `Undetermined { DeviceNotEngaged }` unless the device was engaged and the
    /// rate is finite and strictly positive; otherwise `Met` when `measured >= target`,
    /// else `Unmet`.
    #[must_use]
    pub fn from_device_measurement_against(
        engaged: bool,
        measured_rows_per_sec: f64,
        target_rows_per_sec: f64,
    ) -> Self {
        let rate_is_usable = measured_rows_per_sec > 0.0 && measured_rows_per_sec.is_finite();
        if !(engaged && rate_is_usable) {
            // NOTE(review): an engaged run with an unusable rate is reported as
            // `DeviceNotEngaged` too; there is no dedicated "bad measurement" variant.
            return Self::blocked(EncodeDecisionBlocked::DeviceNotEngaged);
        }
        // Keep `>=` (not `!(<)`) so a NaN target yields `Unmet`.
        let meets_target = measured_rows_per_sec >= target_rows_per_sec;
        if meets_target {
            Self::Met {
                measured_rows_per_sec,
                target_rows_per_sec,
            }
        } else {
            Self::Unmet {
                measured_rows_per_sec,
                target_rows_per_sec,
            }
        }
    }

    /// Builds an `Undetermined` decision with the given reason.
    #[must_use]
    pub fn blocked(reason: EncodeDecisionBlocked) -> Self {
        Self::Undetermined { reason }
    }

    /// `true` only for `Met`: the device already meets the target.
    #[must_use]
    pub fn surrogate_unneeded(&self) -> bool {
        match self {
            Self::Met { .. } => true,
            Self::Unmet { .. } | Self::Undetermined { .. } => false,
        }
    }

    /// `true` only for `Unmet`: the device measurably misses the target.
    #[must_use]
    pub fn surrogate_justified(&self) -> bool {
        match self {
            Self::Unmet { .. } => true,
            Self::Met { .. } | Self::Undetermined { .. } => false,
        }
    }

    /// `true` only for `Undetermined`.
    #[must_use]
    pub fn is_undetermined(&self) -> bool {
        match self {
            Self::Undetermined { .. } => true,
            Self::Met { .. } | Self::Unmet { .. } => false,
        }
    }
}
/// Family/link combinations a device PIRLS loop may be asked to run
/// (names presumably read `<Family><Link>` — TODO confirm).
///
/// The admission predicates accept any `Some(kind)` without distinguishing variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PirlsLoopFamilyKind {
    BernoulliLogit,
    BernoulliProbit,
    BernoulliCLogLog,
    PoissonLog,
    GaussianIdentity,
    GammaLog,
}
/// Which curvature the PIRLS loop uses.
///
/// NOTE(review): not read by the admission predicates in this section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PirlsLoopCurvatureKind {
    Fisher,
    Observed,
}
/// Inputs to `GpuDispatchPolicy::should_run_reml_outer_on_device`.
#[derive(Debug, Clone, Copy)]
pub struct RemlOuterAdmission {
    /// Row count.
    pub n: usize,
    /// Column count.
    pub p: usize,
    /// Number of smoothing parameters; at least 2 are required for device admission.
    pub num_rho: usize,
    /// Family/link; `None` blocks device admission.
    pub family: Option<PirlsLoopFamilyKind>,
    /// Curvature kind. NOTE(review): not consulted by the admission predicate.
    pub curvature: PirlsLoopCurvatureKind,
    /// Whether a device is available at all.
    pub gpu_available: bool,
}
/// Inputs to `GpuDispatchPolicy::should_use_gpu_pirls_loop`.
#[derive(Debug, Clone, Copy)]
pub struct PirlsLoopAdmission {
    /// Row count.
    pub n: usize,
    /// Column count.
    pub p: usize,
    /// Family/link; `None` blocks device admission.
    pub family: Option<PirlsLoopFamilyKind>,
    /// Curvature kind. NOTE(review): not consulted by the admission predicate.
    pub curvature: PirlsLoopCurvatureKind,
    /// Whether a device is available at all.
    pub gpu_available: bool,
}
impl GpuDispatchPolicy {
    /// Minimum dense width `p` for device Hessian work, and minimum border width `k`
    /// for reduced-Schur mat-vec offload.
    pub const DEVICE_LOOP_MIN_P: usize = 32;

    /// Admits the PIRLS loop to the device when a device is available, the dense
    /// Hessian work clears `dense_hessian_work_target_is_gpu`, and a family is set.
    pub const fn should_use_gpu_pirls_loop(&self, adm: PirlsLoopAdmission) -> bool {
        adm.gpu_available
            && self.dense_hessian_work_target_is_gpu(adm.n, adm.p)
            && matches!(adm.family, Some(_))
    }

    /// Admits the REML outer loop to the device under the PIRLS conditions plus at
    /// least two smoothing parameters (`num_rho >= 2`).
    pub const fn should_run_reml_outer_on_device(&self, adm: RemlOuterAdmission) -> bool {
        adm.gpu_available
            && self.dense_hessian_work_target_is_gpu(adm.n, adm.p)
            && adm.num_rho >= 2
            && matches!(adm.family, Some(_))
    }
}
#[cfg(test)]
mod refinement_policy_tests {
    use super::*;

    /// Default thresholds with the given mixed-precision mode.
    fn policy_with(mode: GpuMixedPrecisionPolicy) -> GpuDispatchPolicy {
        GpuDispatchPolicy {
            mixed_precision: mode,
            ..GpuDispatchPolicy::default()
        }
    }

    #[test]
    fn refinement_policy_admits_large_p() {
        let policy = GpuDispatchPolicy::default();
        for p in [512, GpuDispatchPolicy::REFINEMENT_MIN_P] {
            assert!(policy.iterative_refinement_should_attempt(p));
        }
    }

    #[test]
    fn refinement_policy_rejects_small_p() {
        let policy = GpuDispatchPolicy::default();
        for p in [GpuDispatchPolicy::REFINEMENT_MIN_P - 1, 0] {
            assert!(!policy.iterative_refinement_should_attempt(p));
        }
    }

    #[test]
    fn off_policy_never_attempts_refinement() {
        let policy = policy_with(GpuMixedPrecisionPolicy::Off);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }

    #[test]
    fn never_policy_never_attempts_refinement() {
        let policy = policy_with(GpuMixedPrecisionPolicy::Never);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }
}
#[cfg(test)]
mod reduced_schur_matvec_offload_tests {
    use super::*;

    /// Offload decision under the default policy.
    fn offload(n: usize, k: usize, d: usize, cg_iters: usize) -> bool {
        GpuDispatchPolicy::default().reduced_schur_matvec_should_offload(n, k, d, cg_iters)
    }

    #[test]
    fn admits_llm_sae_matvec_shape() {
        assert!(offload(
            2_000,
            2_048,
            8,
            GpuDispatchPolicy::MATVEC_OFFLOAD_MIN_CG_ITERS,
        ));
        // The same width-8 block is far too narrow for the dense-Hessian device path.
        assert!(!GpuDispatchPolicy::default().dense_hessian_work_target_is_gpu(2_000, 8));
    }

    #[test]
    fn admits_llm_shape_with_one_cg_iter() {
        assert!(offload(2_000, 2_048, 8, 1));
    }

    #[test]
    fn admits_thin_curve_atoms_at_realistic_scale() {
        assert!(offload(24_576, 64, 1, 1));
        assert!(offload(40_456, 256, 1, 1));
        assert!(!offload(300, 6, 1, 8));
    }

    #[test]
    fn rejects_tiny_shape_where_transfer_dominates() {
        let tiny = [
            (30, 8, 2, GpuDispatchPolicy::MATVEC_OFFLOAD_MIN_CG_ITERS),
            (300, 8, 4, 16),
        ];
        for (n, k, d, iters) in tiny {
            assert!(!offload(n, k, d, iters));
        }
    }

    #[test]
    fn rejects_narrow_border_even_with_huge_row_count() {
        let narrow = GpuDispatchPolicy::DEVICE_LOOP_MIN_P - 1;
        assert!(!offload(1_000_000, narrow, 64, 64));
    }

    #[test]
    fn rejects_degenerate_dimensions() {
        let cases = [
            (0, 2_048, 8, 8),
            (2_000, 0, 8, 8),
            (2_000, 2_048, 0, 8),
            (2_000, 2_048, 8, 0),
        ];
        for (n, k, d, iters) in cases {
            assert!(!offload(n, k, d, iters));
        }
    }

    #[test]
    fn monotone_in_cg_iters() {
        let (n, k, d) = (200usize, GpuDispatchPolicy::DEVICE_LOOP_MIN_P, 4usize);
        assert!(!offload(n, k, d, 1));
        for iters in [1_000, 5_000] {
            assert!(offload(n, k, d, iters));
        }
    }

    #[test]
    fn admission_lower_bound_undercounts_actual_work() {
        let shapes: [(usize, usize, usize); 3] = [
            (2_000, 2_048, 8),
            (200, GpuDispatchPolicy::DEVICE_LOOP_MIN_P, 4),
            (1, 1, 1),
        ];
        for (n, k, d) in shapes {
            let lower = GpuDispatchPolicy::admission_work_lower_bound(n, k, d);
            let (rows, border, width) = (n as u128, k as u128, d as u128);
            let actual = rows * (4 * width * border + width * width);
            assert!(
                lower < actual,
                "admission lower bound {lower} must undercount actual work {actual} for ({n},{k},{d})"
            );
        }
    }
}
#[cfg(test)]
mod arrow_border_solve_plan_tests {
    use super::*;

    #[test]
    fn color_arm_recommends_reduced_iterative_and_flags_rank_deficiency() {
        let pol = GpuDispatchPolicy::default();
        let plan = pol.arrow_border_solve_plan(180, 15_360, 2, 30);
        assert_eq!(plan.recommended, ArrowBorderStrategy::ReducedIterative);
        assert!(plan.dense_border_rank_deficient);
        assert_eq!(plan.data_fit_rank, 360);
        assert!(plan.dense_direct_flops > plan.reduced_iterative_flops * 100);
        assert!(plan.device_favorable);
    }

    #[test]
    fn small_square_border_recommends_dense_direct() {
        let pol = GpuDispatchPolicy::default();
        // Dense assembly costs ~2*n*d*k^2 versus ~4*n*d*k per CG apply, so dense direct
        // only wins once the CG budget exceeds roughly k / 2. For k = 64 the previous
        // 8-iteration budget costs 825_600 flops against 3_364_181 for dense direct
        // (iterative wins); a 64-iteration budget (6_604_800) makes dense direct cheaper.
        let plan = pol.arrow_border_solve_plan(200, 64, 2, 64);
        assert_eq!(plan.recommended, ArrowBorderStrategy::DenseDirect);
        assert!(plan.dense_direct_flops <= plan.reduced_iterative_flops);
        assert!(!plan.dense_border_rank_deficient);
        assert_eq!(plan.data_fit_rank, 64);
        // Regression guard: the same shape with a short budget goes iterative.
        assert_eq!(
            pol.arrow_border_solve_plan(200, 64, 2, 8).recommended,
            ArrowBorderStrategy::ReducedIterative
        );
    }

    #[test]
    fn rank_flag_and_clamp_track_n_d_versus_k() {
        let pol = GpuDispatchPolicy::default();
        let exact = pol.arrow_border_solve_plan(50, 100, 2, 8);
        assert!(!exact.dense_border_rank_deficient);
        assert_eq!(exact.data_fit_rank, 100);
        let deficient = pol.arrow_border_solve_plan(49, 100, 2, 8);
        assert!(deficient.dense_border_rank_deficient);
        assert_eq!(deficient.data_fit_rank, 98);
    }

    #[test]
    fn wider_border_only_moves_toward_iterative() {
        let pol = GpuDispatchPolicy::default();
        let narrow = pol.arrow_border_solve_plan(200, 128, 4, 16);
        let wide = pol.arrow_border_solve_plan(200, 8_192, 4, 16);
        assert_eq!(wide.recommended, ArrowBorderStrategy::ReducedIterative);
        let narrow_ratio = narrow.dense_direct_flops as f64 / narrow.reduced_iterative_flops as f64;
        let wide_ratio = wide.dense_direct_flops as f64 / wide.reduced_iterative_flops as f64;
        assert!(wide_ratio > narrow_ratio);
    }

    #[test]
    fn larger_cg_budget_never_switches_away_from_dense() {
        let pol = GpuDispatchPolicy::default();
        let shape = (200usize, 96usize, 3usize);
        let small = pol.arrow_border_solve_plan(shape.0, shape.1, shape.2, 4);
        let large = pol.arrow_border_solve_plan(shape.0, shape.1, shape.2, 400);
        if small.recommended == ArrowBorderStrategy::DenseDirect {
            assert_eq!(large.recommended, ArrowBorderStrategy::DenseDirect);
        }
        assert!(large.reduced_iterative_flops >= small.reduced_iterative_flops);
    }

    #[test]
    fn degenerate_shapes_are_trivial_dense_and_not_device_favorable() {
        let pol = GpuDispatchPolicy::default();
        for shape in [(0usize, 100usize, 2usize), (100, 0, 2), (100, 100, 0)] {
            let plan = pol.arrow_border_solve_plan(shape.0, shape.1, shape.2, 8);
            assert_eq!(plan.recommended, ArrowBorderStrategy::DenseDirect);
            assert!(!plan.device_favorable);
            assert_eq!(plan.dense_direct_flops, 0);
            assert_eq!(plan.reduced_iterative_flops, 0);
        }
    }

    #[test]
    fn zero_cg_budget_is_treated_as_one_apply() {
        let pol = GpuDispatchPolicy::default();
        let plan = pol.arrow_border_solve_plan(180, 15_360, 2, 0);
        assert_eq!(plan.cg_iters, 1);
        assert!(plan.reduced_iterative_flops > 0);
    }
}
#[cfg(test)]
mod encode_deployment_decision_tests {
    use super::*;

    /// Asserts that `decision` reached no verdict in either direction.
    fn assert_no_verdict(decision: &EncodeDeploymentDecision) {
        assert!(decision.is_undetermined());
        assert!(!decision.surrogate_unneeded());
        assert!(!decision.surrogate_justified());
    }

    #[test]
    fn cpu_rate_can_never_meet_or_refute_the_target() {
        let cpu_only = EncodeDeploymentDecision::blocked(EncodeDecisionBlocked::NoDevice);
        assert_no_verdict(&cpu_only);

        // A huge rate from a run that never engaged the device proves nothing.
        let false_routed = EncodeDeploymentDecision::from_device_measurement(false, 1.0e9);
        assert!(false_routed.is_undetermined());
        assert!(!false_routed.surrogate_unneeded());
    }

    #[test]
    fn engaged_measurement_decides_by_the_number() {
        let target = GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;
        let decide = |rate: f64| EncodeDeploymentDecision::from_device_measurement(true, rate);

        let met = decide(target * 2.0);
        assert!(matches!(met, EncodeDeploymentDecision::Met { .. }));
        assert!(met.surrogate_unneeded());
        assert!(!met.surrogate_justified());
        assert!(!met.is_undetermined());

        let unmet = decide(target * 0.25);
        assert!(matches!(unmet, EncodeDeploymentDecision::Unmet { .. }));
        assert!(unmet.surrogate_justified());
        assert!(!unmet.surrogate_unneeded());

        // Hitting the target exactly counts as meeting it.
        assert!(decide(target).surrogate_unneeded());
    }

    #[test]
    fn engaged_but_non_usable_rate_is_undetermined_not_a_pass() {
        let unusable = [0.0, -1.0, f64::NAN, f64::INFINITY];
        for bad in unusable {
            let d = EncodeDeploymentDecision::from_device_measurement(true, bad);
            assert!(
                d.is_undetermined(),
                "an engaged-but-unusable rate {bad} must be Undetermined, not a decision"
            );
            assert!(!d.surrogate_unneeded());
            assert!(!d.surrogate_justified());
        }
    }

    #[test]
    fn blocked_reasons_are_all_undetermined() {
        let reasons = [
            EncodeDecisionBlocked::NoDevice,
            EncodeDecisionBlocked::NoDeviceEncodeKernel,
            EncodeDecisionBlocked::DeviceNotEngaged,
        ];
        for reason in reasons {
            assert_no_verdict(&EncodeDeploymentDecision::blocked(reason));
        }
    }
}