pub struct GpuDispatchPolicy {
pub xtwx_n_min: usize,
pub xtwx_flops_min: usize,
pub xtwx_use_fused_below_p: usize,
pub gemm_min_flops: usize,
pub potrf_min_p: usize,
pub small_dense_batched_potrf_max_p: usize,
pub small_dense_batched_potrf_min_batch: usize,
pub syevd_min_p: usize,
pub sparse_min_nnz: usize,
pub fused_kernel_min_n: usize,
pub keep_design_resident_min_bytes: usize,
pub prefer_gpu_factorization_min_p: usize,
pub row_kernel_min_n: usize,
pub mixed_precision: GpuMixedPrecisionPolicy,
}
Implementations
impl GpuDispatchPolicy
pub const REFINEMENT_MIN_P: usize = 64
Minimum problem dimension for the fp32 + refinement path.
Below this threshold the fp64 GEMV needed for the residual check costs
more than the fp32 factorization saves. The threshold is chosen so that a
single p × p DGEMV (2p² flops) is at least 10× cheaper than the p³/3
POTRF, which holds from p = 60 on, so p ≥ 64 satisfies it while still
leaving margin for the POTRF/POTRS launches. In practice this floor is
subsumed by the existing potrf_min_p = 512 floor for GPU dispatch, so the
refinement path only activates when the GPU factorization path is already
chosen.
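The 10× ratio can be checked directly from the leading-order flop counts quoted above. This is a minimal sketch that assumes only the 2p² DGEMV and p³/3 POTRF terms (no lower-order terms and no launch costs); the helper names are illustrative, not part of the crate:

```rust
/// Leading-order flops for one p × p matrix-vector product (DGEMV): 2p².
fn gemv_flops(p: u64) -> u64 {
    2 * p * p
}

/// Leading-order flops for a p × p Cholesky factorization (POTRF): p³/3.
fn potrf_flops(p: u64) -> u64 {
    p * p * p / 3
}

/// Smallest p for which the GEMV is at least `ratio` times cheaper than the
/// POTRF. The test ratio · 2p² ≤ p³/3 is rearranged to ratio · 2p² · 3 ≤ p³
/// so it stays exact in integer arithmetic.
fn min_p_for_ratio(ratio: u64) -> u64 {
    (1u64..)
        .find(|&p| ratio * gemv_flops(p) * 3 <= p * p * p)
        .expect("the cubic term eventually dominates")
}

fn main() {
    let p_min = min_p_for_ratio(10);
    println!("GEMV is at least 10x cheaper than POTRF from p = {p_min}");
    for p in [32u64, 64, 512] {
        println!(
            "p = {p}: GEMV {} flops, POTRF {} flops",
            gemv_flops(p),
            potrf_flops(p)
        );
    }
}
```

The search stops at p = 60, so the documented value of 64 clears the 10× margin with a little room to spare.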
pub const REFINEMENT_MAX_STEPS: usize = 3
Maximum number of fp32 correction steps per solve.
Two steps suffice for κ(A) ≤ 10⁵ at fp32 (u ≈ 6 × 10⁻⁸): after step 1 the error is O((κu)²) ≈ 10⁻⁶, and after step 2 it is O((κu)⁴) ≈ 10⁻¹², which is at or below the attainable fp64 accuracy of roughly 10⁻¹⁶ × κ. A cap of 3 is used defensively.
pub const REFINEMENT_TOL: f64 = 1e-12
Relative residual tolerance for declaring convergence.
A solve is declared converged when ‖r‖ / ‖b‖ ≤ tol. The value 10⁻¹²
sits about two orders of magnitude above fp64 machine epsilon times a
moderate condition number, which keeps the policy conservative.
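The REFINEMENT_* constants parameterise a standard mixed-precision iterative-refinement loop. The following is a minimal CPU-only sketch, not the crate's GPU implementation: it factors A once in f32 with a hand-written Cholesky, computes residuals in f64, solves for corrections with the f32 factor, and stops after REFINEMENT_MAX_STEPS corrections or once ‖r‖/‖b‖ ≤ REFINEMENT_TOL. The helper names (cholesky_f32, solve_f32, refine) are hypothetical, and the non-monotone-residual fallback is omitted for brevity.

```rust
// Values copied from the REFINEMENT_* constants documented above.
const REFINEMENT_MAX_STEPS: usize = 3;
const REFINEMENT_TOL: f64 = 1e-12;

/// Row-major lower Cholesky factor of an SPD matrix, computed in f32.
/// Returns None on a non-positive pivot (the fp32 factorization fails).
fn cholesky_f32(a: &[f64], n: usize) -> Option<Vec<f32>> {
    let mut l = vec![0.0f32; n * n];
    for j in 0..n {
        let mut diag = a[j * n + j] as f32;
        for k in 0..j {
            diag -= l[j * n + k] * l[j * n + k];
        }
        if diag <= 0.0 {
            return None;
        }
        let ljj = diag.sqrt();
        l[j * n + j] = ljj;
        for i in (j + 1)..n {
            let mut s = a[i * n + j] as f32;
            for k in 0..j {
                s -= l[i * n + k] * l[j * n + k];
            }
            l[i * n + j] = s / ljj;
        }
    }
    Some(l)
}

/// Solves (L Lᵀ) y = b in f32 by forward then back substitution.
fn solve_f32(l: &[f32], n: usize, b: &[f32]) -> Vec<f32> {
    let mut y = b.to_vec();
    for i in 0..n {
        for k in 0..i {
            y[i] -= l[i * n + k] * y[k];
        }
        y[i] /= l[i * n + i];
    }
    for i in (0..n).rev() {
        for k in (i + 1)..n {
            y[i] -= l[k * n + i] * y[k];
        }
        y[i] /= l[i * n + i];
    }
    y
}

fn to_f32(v: &[f64]) -> Vec<f32> {
    v.iter().map(|&e| e as f32).collect()
}

fn norm2(v: &[f64]) -> f64 {
    v.iter().map(|e| e * e).sum::<f64>().sqrt()
}

/// Residual r = b - A x, accumulated in f64 against the original matrix.
fn residual(a: &[f64], n: usize, x: &[f64], b: &[f64]) -> Vec<f64> {
    (0..n)
        .map(|i| b[i] - (0..n).map(|j| a[i * n + j] * x[j]).sum::<f64>())
        .collect()
}

/// Returns (x, final ‖r‖/‖b‖, correction steps), or None if the fp32
/// factorization fails. Assumes b != 0 so the relative residual is defined.
fn refine(a: &[f64], n: usize, b: &[f64]) -> Option<(Vec<f64>, f64, usize)> {
    let l = cholesky_f32(a, n)?;
    let mut x: Vec<f64> = solve_f32(&l, n, &to_f32(b)).into_iter().map(f64::from).collect();
    let b_norm = norm2(b);
    let mut r = residual(a, n, &x, b);
    let mut rel = norm2(&r) / b_norm;
    let mut steps = 0;
    while rel > REFINEMENT_TOL && steps < REFINEMENT_MAX_STEPS {
        // The correction reuses the fp32 factor; only the residual is fp64.
        let d = solve_f32(&l, n, &to_f32(&r));
        for (xi, di) in x.iter_mut().zip(d) {
            *xi += f64::from(di);
        }
        r = residual(a, n, &x, b);
        rel = norm2(&r) / b_norm;
        steps += 1;
    }
    Some((x, rel, steps))
}

fn main() {
    // Small SPD system whose exact solution is x = [2/9, 1/9, 13/9].
    let a = [4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0];
    let b = [1.0, 2.0, 3.0];
    match refine(&a, 3, &b) {
        Some((x, rel, steps)) => println!("x = {x:?}, rel = {rel:e}, steps = {steps}"),
        None => println!("fp32 factorization failed; fall back to fp64"),
    }
}
```

Returning None on a failed fp32 factorization leaves the fp64 fallback to the caller, matching the division of responsibility described for the policy predicate.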
pub const MATVEC_OFFLOAD_FLOPS_MIN: u128 = 10_000_000
Minimum total matvec work, in flops summed over the whole CG solve, for offloading the reduced-Schur PCG hot loop. Below it, the host↔device transfer of the row frames and CG vectors is not repaid by the device matvec, so the loop stays on the CPU.
The dense-Direct path keys on dense_reduction_flops_min (a single big
factorization). The matrix-free SAE matvec is different: no single apply
trips that floor (each is a stack of n tiny d×d solves plus a sparse
m·k gather/scatter), but the CG solve runs the apply O(cg_iters) times
over the same resident frames. The device wins when the summed matvec
work over the solve exceeds the one-time staging cost, so the gate keys on
cg_iters · per_apply_flops rather than on a single apply.
The value is set one order of magnitude below the dense floor: the matvec
frames are uploaded once and stay resident across CG iterations, so the
transfer cost charged per flop is 1/cg_iters that of a cold dense launch,
and the breakeven drops accordingly.
pub const MATVEC_OFFLOAD_MIN_CG_ITERS: usize = 8
Conservative seed for the reduced-Schur PCG iteration count when the
caller cannot supply a measured budget. InexactPCG on an SAE β-block of
width k converges in O(√κ) iterations; this floor keeps the work
estimate honest (≥ this many applies) without over-claiming a tight
solve. Used only to amortise the staging cost in the work estimate.
pub const fn iterative_refinement_should_attempt(&self, p: usize) -> bool
Return true when the policy and problem size together suggest that
attempting fp32 factorization + iterative refinement will be profitable.
The predicate is conservative:
- GpuMixedPrecisionPolicy::Off or Never → always false.
- Refinement with p < REFINEMENT_MIN_P → false (GEMV overhead not amortised by fp32 POTRF savings below this threshold).
- Otherwise true; the caller still falls back to fp64 factorization when the runtime fp32 POTRF fails or when the measured residual is non-monotone.
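The decision table reduces to a const match. This sketch assumes a GpuMixedPrecisionPolicy with exactly the three variants named above and takes the mode as a parameter instead of reading it from self; the real enum and method signature may differ.

```rust
/// Hypothetical mirror of the documented variants; the real enum may differ.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum GpuMixedPrecisionPolicy {
    Off,
    Never,
    Refinement,
}

const REFINEMENT_MIN_P: usize = 64;

/// Off/Never never attempt fp32 + refinement; Refinement only at p >= 64.
const fn iterative_refinement_should_attempt(mode: GpuMixedPrecisionPolicy, p: usize) -> bool {
    match mode {
        GpuMixedPrecisionPolicy::Off | GpuMixedPrecisionPolicy::Never => false,
        GpuMixedPrecisionPolicy::Refinement => p >= REFINEMENT_MIN_P,
    }
}

fn main() {
    let modes = [
        GpuMixedPrecisionPolicy::Off,
        GpuMixedPrecisionPolicy::Never,
        GpuMixedPrecisionPolicy::Refinement,
    ];
    for mode in modes {
        for p in [32, 64, 512] {
            println!("{mode:?}, p = {p}: {}", iterative_refinement_should_attempt(mode, p));
        }
    }
}
```

Because the predicate is const and device-free, it can be evaluated in unit tests without any CUDA initialization.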
pub const fn dense_gemv_target_is_gpu(&self, n: usize, p: usize, resident: bool) -> bool
pub const fn xtwx_target_is_gpu(&self, n: usize, p: usize, materialized: bool) -> bool
pub const fn xtwy_target_is_gpu(&self, n: usize, px: usize, q: usize, materialized: bool) -> bool
pub const fn potrf_target_is_gpu(&self, p: usize, h_resident: bool) -> bool
pub const fn dense_hessian_work_target_is_gpu(&self, n: usize, p: usize) -> bool
pub const fn reduced_schur_matvec_should_offload(
&self,
n: usize,
k: usize,
d: usize,
cg_iters: usize,
) -> bool
Work-based admission for offloading the reduced-Schur PCG matvec (the InexactPCG hot loop for matrix-free SAE β-blocks) to the device.
This is the Phase-1 (#1017) re-keying: the dense gates key on row count
(xtwx_n_min, row_kernel_min_n at 50k) or a single big-factorization
flop floor, neither of which the SAE LLM shape trips. An (n ≈ 2000) × (k ≈ 2048) × (d ≈ 8) problem is thousands of small dense ops with no single large op,
so the row-count gate keeps the whole fit on one CPU core. Here the gate
is instead the total batched work over the CG solve:
estimated_device_flops = cg_iters · per_apply_flops(n, k, d)
should_offload = estimated_device_flops ≥ T_breakeven
where T_breakeven = MATVEC_OFFLOAD_FLOPS_MIN accounts for the
host↔device staging of the row frames + CG vectors amortised over the
cg_iters applies that reuse the resident frames. Because the per-flop
transfer cost is 1/cg_iters that of a cold launch, T_breakeven sits an
order of magnitude below the dense-Direct floor.
This is a pure function of the shape: no device is needed to evaluate it,
so it is unit-testable. The caller still falls back to the bit-identical
CPU matvec whenever the backend build declines, so admitting a shape never
changes the numerics, only where the Σ_i Y_iᵀ(Y_i x) flops execute.
- n: number of row blocks (SAE observations / latent rows).
- k: border β width (the SAE decoder atom count K).
- d: per-row latent / active-frame depth (the M dimension).
- cg_iters: expected PCG iteration budget; the per-apply work is multiplied by this because the frames stay resident across iterations. Pass Self::MATVEC_OFFLOAD_MIN_CG_ITERS when no measured budget is available; a tighter (smaller) value only makes the gate stricter.
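The gate itself is one multiplication and one comparison. In the sketch below, per_apply_flops is a hypothetical cost model (per row block, a d×d solve pair at about 2d² flops plus a dense d×k gather and scatter at about 4dk flops), not the crate's function; only the two constants are taken from this page. Arithmetic is done in u128 to match the type of MATVEC_OFFLOAD_FLOPS_MIN.

```rust
const MATVEC_OFFLOAD_FLOPS_MIN: u128 = 10_000_000;
const MATVEC_OFFLOAD_MIN_CG_ITERS: usize = 8;

/// Hypothetical per-apply cost model (an assumption, not the crate's):
/// per row block, ~2d² for the d×d solves plus ~4dk for Y_i x and Y_iᵀ(·).
const fn per_apply_flops(n: usize, k: usize, d: usize) -> u128 {
    let (n, k, d) = (n as u128, k as u128, d as u128);
    n * (2 * d * d + 4 * d * k)
}

/// Work-based gate: total flops over the CG solve vs. the staging breakeven.
const fn reduced_schur_matvec_should_offload(n: usize, k: usize, d: usize, cg_iters: usize) -> bool {
    let estimated_device_flops = (cg_iters as u128) * per_apply_flops(n, k, d);
    estimated_device_flops >= MATVEC_OFFLOAD_FLOPS_MIN
}

fn main() {
    let shapes = [(10, 16, 4), (1000, 256, 4), (2000, 2048, 8)];
    for (n, k, d) in shapes {
        println!(
            "n={n} k={k} d={d}: per-apply {} flops, offload = {}",
            per_apply_flops(n, k, d),
            reduced_schur_matvec_should_offload(n, k, d, MATVEC_OFFLOAD_MIN_CG_ITERS)
        );
    }
}
```

Under this model the (1000, 256, 4) shape is admitted with the default budget of 8 iterations but rejected with a budget of 1, which illustrates why a smaller cg_iters can only make the gate stricter.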
Live arrow-Schur call site
crate::solver::arrow_schur::maybe_inject_gpu_schur_matvec gates the
InexactPCG reduced-Schur matvec injection on this predicate:
reduced_schur_matvec_should_offload(sys.rows.len(), sys.k, sys.d, options.pcg.max_iterations.min(options.trust_region.max_iterations)),
where sys.d is the system’s max per-row latent depth and the iteration
budget is the same max_iterations the PCG loop launches with.
try_device_arrow_direct (the dense Direct point solve) correctly
keeps gating on dense_hessian_work_target_is_gpu: that path is a single
large factorization, not the amortised matvec.
impl GpuDispatchPolicy
pub const DEVICE_LOOP_MIN_P: usize = 32
Minimum design column count for the device-resident inner/outer loops.
Below this width the per-iteration XᵀWX + Cholesky is dominated by
launch latency and PCIe staging rather than arithmetic, so the host LM
loop (which populates the full PirlsResult surface as a free
side-effect) is strictly cheaper. Shared by both the inner PIRLS and
outer REML admission predicates so they cannot drift apart.
pub const fn should_use_gpu_pirls_loop(&self, adm: PirlsLoopAdmission) -> bool
Conservative admission predicate for routing
fit_model_for_fixed_rho_with_adaptive_kkt through the Stage 3.3
device-resident PIRLS loop instead of the CPU LM loop.
The threshold is the dense XᵀWX work estimate, not row count alone:
LLM/SAE fits can have only a few thousand rows but thousands of columns,
so 2*n*p^2 already dwarfs launch/staging overhead. Smaller fits stay on
the CPU LM loop where the full PirlsResult surface (firth, EDF,
per-row weights, …) is already populated as a free side-effect of the
iteration.
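A work-based admission in the spirit of this paragraph can be sketched as follows. It assumes PirlsLoopAdmission carries at least n and p (the real struct's fields are not shown on this page) and passes the flop floor as a hypothetical min_flops parameter, since the page does not say which policy field supplies it. The sketch combines the DEVICE_LOOP_MIN_P width floor with the 2·n·p² estimate:

```rust
const DEVICE_LOOP_MIN_P: usize = 32;

/// Hypothetical stand-in for PirlsLoopAdmission: only the shape fields used here.
#[derive(Clone, Copy)]
struct PirlsLoopAdmission {
    n: usize,
    p: usize,
}

/// Leading-order dense XᵀWX cost: 2·n·p² flops.
const fn xtwx_flops(n: usize, p: usize) -> u128 {
    2 * (n as u128) * (p as u128) * (p as u128)
}

/// Admit the device loop only for wide designs whose dense work clears the floor.
const fn should_use_gpu_pirls_loop(adm: PirlsLoopAdmission, min_flops: u128) -> bool {
    adm.p >= DEVICE_LOOP_MIN_P && xtwx_flops(adm.n, adm.p) >= min_flops
}

fn main() {
    // Hypothetical floor, for illustration only.
    let min_flops: u128 = 1_000_000_000;
    let fits = [
        PirlsLoopAdmission { n: 3_000, p: 2_000 },   // few rows, many columns
        PirlsLoopAdmission { n: 10_000_000, p: 16 }, // many rows, narrow design
        PirlsLoopAdmission { n: 100, p: 64 },        // small fit
    ];
    for adm in fits {
        println!(
            "n={} p={}: 2np² = {}, device loop = {}",
            adm.n,
            adm.p,
            xtwx_flops(adm.n, adm.p),
            should_use_gpu_pirls_loop(adm, min_flops)
        );
    }
}
```

The narrow case shows why the width floor matters: its total work clears the floor, but each per-iteration XᵀWX is too thin to hide launch latency.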
pub const fn should_run_reml_outer_on_device(
&self,
adm: RemlOuterAdmission,
) -> bool
Admission predicate for routing the outer REML BFGS-over-ρ loop onto a device-resident driver that keeps the BFGS state (ρ, gradient, Hessian approx) on-device and only downloads the per-step scalar metrics (objective value, gradient norm, convergence flag).
The dense-work threshold piggybacks on the existing inner-PIRLS admission
predicate because the device-resident outer loop calls
pirls_loop_on_stream per step and must not pay the host hop for small
fits the inner loop would have rejected anyway. The
num_rho ≥ 2 floor rules out the trivial single-smoother case where
host orchestration is already negligible and the device BFGS state
(one length-num_rho gradient + a num_rho × num_rho Hessian
approx) collapses to a couple of scalars not worth keeping on device.
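The outer predicate composes the inner work gate with the num_rho floor. This sketch treats RemlOuterAdmission's fields (n, p, num_rho) and the min_flops floor as hypothetical, and reuses one DEVICE_LOOP_MIN_P constant so the inner and outer gates share a single width floor:

```rust
const DEVICE_LOOP_MIN_P: usize = 32;

/// Hypothetical stand-in for RemlOuterAdmission (fields assumed).
#[derive(Clone, Copy, Debug)]
struct RemlOuterAdmission {
    n: usize,
    p: usize,
    num_rho: usize,
}

/// Inner-PIRLS work gate: width floor plus the 2·n·p² dense XᵀWX estimate.
const fn inner_pirls_admits(n: usize, p: usize, min_flops: u128) -> bool {
    p >= DEVICE_LOOP_MIN_P && 2 * (n as u128) * (p as u128) * (p as u128) >= min_flops
}

/// Outer REML gate: at least two smoothing parameters, and the inner
/// PIRLS gate would admit the same shape.
const fn should_run_reml_outer_on_device(adm: RemlOuterAdmission, min_flops: u128) -> bool {
    adm.num_rho >= 2 && inner_pirls_admits(adm.n, adm.p, min_flops)
}

fn main() {
    let min_flops: u128 = 1_000_000_000; // hypothetical floor
    for num_rho in [1, 2, 5] {
        let adm = RemlOuterAdmission { n: 3_000, p: 2_000, num_rho };
        println!(
            "{adm:?} -> device outer loop: {}",
            should_run_reml_outer_on_device(adm, min_flops)
        );
    }
}
```

Routing both gates through the same inner check is what keeps the outer loop from admitting a shape whose per-step PIRLS solve would have stayed on the host.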
Trait Implementations
impl Clone for GpuDispatchPolicy
fn clone(&self) -> GpuDispatchPolicy
fn clone_from(&mut self, source: &Self)
Performs copy-assignment from source.
impl Debug for GpuDispatchPolicy
impl Default for GpuDispatchPolicy
fn default() -> Self
Conservative seed thresholds used before device calibration and when calibration cannot run on the current host.
The production runtime replaces these with
[crate::calibration::calibrated_policy_for_device] after the CUDA
probe selects a concrete device. Keep these values conservative: they
are the typed baseline for CPU-only builds, failed calibration, and unit
tests that exercise policy predicates without initializing CUDA.
impl<'de> Deserialize<'de> for GpuDispatchPolicy
fn deserialize<__D>(__deserializer: __D) -> Result<Self, __D::Error> where
__D: Deserializer<'de>,
impl Eq for GpuDispatchPolicy
impl PartialEq for GpuDispatchPolicy
fn eq(&self, other: &GpuDispatchPolicy) -> bool
Tests for self and other values to be equal, and is used by ==.
impl Serialize for GpuDispatchPolicy
impl StructuralPartialEq for GpuDispatchPolicy
Auto Trait Implementations
impl Freeze for GpuDispatchPolicy
impl RefUnwindSafe for GpuDispatchPolicy
impl Send for GpuDispatchPolicy
impl Sync for GpuDispatchPolicy
impl Unpin for GpuDispatchPolicy
impl UnsafeUnpin for GpuDispatchPolicy
impl UnwindSafe for GpuDispatchPolicy
Blanket Implementations
impl<T> BorrowMut<T> for T where
T: ?Sized,
fn borrow_mut(&mut self) -> &mut T
impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
impl<T> CloneToUninit for T where
T: Clone,
impl<T> DeserializeOwned for T where
T: for<'de> Deserialize<'de>,
impl<T> DistributionExt for T where
T: ?Sized,
impl<T, U> Imply<T> for U
impl<T> IntoEither for T
fn into_either(self, into_left: bool) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left is true.
Converts self into a Right variant of Either<Self, Self>
otherwise.
fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
Converts self into a Left variant of Either<Self, Self>
if into_left(&self) returns true.
Converts self into a Right variant of Either<Self, Self>
otherwise.