Struct GpuDispatchPolicy 

Source
pub struct GpuDispatchPolicy {
    pub xtwx_n_min: usize,
    pub xtwx_flops_min: usize,
    pub xtwx_use_fused_below_p: usize,
    pub gemm_min_flops: usize,
    pub potrf_min_p: usize,
    pub small_dense_batched_potrf_max_p: usize,
    pub small_dense_batched_potrf_min_batch: usize,
    pub syevd_min_p: usize,
    pub sparse_min_nnz: usize,
    pub fused_kernel_min_n: usize,
    pub keep_design_resident_min_bytes: usize,
    pub prefer_gpu_factorization_min_p: usize,
    pub row_kernel_min_n: usize,
    pub mixed_precision: GpuMixedPrecisionPolicy,
}

Fields§

§xtwx_n_min: usize
§xtwx_flops_min: usize
§xtwx_use_fused_below_p: usize
§gemm_min_flops: usize
§potrf_min_p: usize
§small_dense_batched_potrf_max_p: usize
§small_dense_batched_potrf_min_batch: usize
§syevd_min_p: usize
§sparse_min_nnz: usize
§fused_kernel_min_n: usize
§keep_design_resident_min_bytes: usize
§prefer_gpu_factorization_min_p: usize
§row_kernel_min_n: usize
§mixed_precision: GpuMixedPrecisionPolicy

Implementations§

Source§

impl GpuDispatchPolicy

Source

pub const REFINEMENT_MIN_P: usize = 64

Minimum problem dimension for the fp32+refinement path.

Below this threshold the fp64 GEMV needed for the residual check costs more than the fp32 factorization saves. The threshold is set so that a single p × p DGEMV (2p² flops) is at least 10× cheaper than the p³/3-flop POTRF, which holds from p = 60; the constant is set at 64, leaving margin for the POTRF/POTRS launches. In practice the existing potrf_min_p = 512 floor for GPU dispatch already implies p ≥ 64, so the refinement path only activates when the GPU factorization path is already chosen.
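The break-even arithmetic behind this constant can be checked directly. The sketch below is illustrative only: the helper names are hypothetical, and reading 64 as "the break-even point rounded up to the next power of two" is an assumption that happens to agree with the margin the text mentions.

```rust
/// Flop count of one p x p dense matrix-vector product (DGEMV).
fn gemv_flops(p: u64) -> u64 {
    2 * p * p
}

/// Leading-order flop count of a p x p Cholesky factorization (POTRF).
fn potrf_flops(p: u64) -> u64 {
    p * p * p / 3
}

/// Smallest p at which `ratio` GEMVs cost no more than one POTRF.
/// Hypothetical helper, not part of the documented API.
fn min_p_for_ratio(ratio: u64) -> u64 {
    (1..)
        .find(|&p| ratio * gemv_flops(p) <= potrf_flops(p))
        .unwrap()
}

fn main() {
    let p = min_p_for_ratio(10);
    println!(
        "10x break-even at p = {p}; next power of two = {}",
        p.next_power_of_two()
    );
}
```

With a 10× ratio the inequality 20p² ≤ p³/3 first holds at p = 60, and the next power of two is 64.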

Source

pub const REFINEMENT_MAX_STEPS: usize = 3

Maximum number of fp32-correction steps per solve.

Two steps suffice for κ(A) ≤ 10⁵ at fp32 (u ≈ 6 × 10⁻⁸): after step 1 the error is O(κ u)² ≈ 10⁻⁶, after step 2 it is O(κ u)⁴ ≈ 10⁻¹², which is well within the fp64 unit roundoff of 10⁻¹⁶ × κ. A cap of 3 is used defensively.

Source

pub const REFINEMENT_TOL: f64 = 1e-12

Relative residual tolerance for declaring convergence.

‖r‖ / ‖b‖ ≤ tol is considered a converged solve. 10⁻¹² is two orders of magnitude above the fp64 machine epsilon times a moderate condition number, leaving the policy conservative.
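To make the interplay of REFINEMENT_MAX_STEPS and REFINEMENT_TOL concrete, here is a minimal CPU-only sketch of fp32 factorization with fp64 residual refinement. It is not the crate's GPU path: every function name is hypothetical, the Cholesky is a plain textbook loop, and returning None stands in for the caller's fallback to fp64 factorization.

```rust
/// Cholesky factor L (lower triangular) of an SPD matrix, computed in f32.
/// Returns None when a pivot is not positive (factorization failure).
fn cholesky_f32(a: &[Vec<f64>]) -> Option<Vec<Vec<f32>>> {
    let n = a.len();
    let mut l = vec![vec![0.0f32; n]; n];
    for i in 0..n {
        for j in 0..=i {
            let mut s = a[i][j] as f32;
            for k in 0..j {
                s -= l[i][k] * l[j][k];
            }
            if i == j {
                if s <= 0.0 {
                    return None;
                }
                l[i][j] = s.sqrt();
            } else {
                l[i][j] = s / l[j][j];
            }
        }
    }
    Some(l)
}

/// Solve L L^T x = b in f32: forward substitution, then backward substitution.
fn solve_f32(l: &[Vec<f32>], b: &[f32]) -> Vec<f32> {
    let n = l.len();
    let mut y = vec![0.0f32; n];
    for i in 0..n {
        let mut s = b[i];
        for k in 0..i {
            s -= l[i][k] * y[k];
        }
        y[i] = s / l[i][i];
    }
    let mut x = vec![0.0f32; n];
    for i in (0..n).rev() {
        let mut s = y[i];
        for k in i + 1..n {
            s -= l[k][i] * x[k];
        }
        x[i] = s / l[i][i];
    }
    x
}

fn norm(v: &[f64]) -> f64 {
    v.iter().map(|x| x * x).sum::<f64>().sqrt()
}

/// Returns (solution, correction steps used, final relative residual),
/// or None when the caller should fall back to an fp64 factorization.
fn solve_refined(a: &[Vec<f64>], b: &[f64]) -> Option<(Vec<f64>, usize, f64)> {
    const MAX_STEPS: usize = 3; // mirrors REFINEMENT_MAX_STEPS
    const TOL: f64 = 1e-12; // mirrors REFINEMENT_TOL
    let l = cholesky_f32(a)?;
    let b32: Vec<f32> = b.iter().map(|&v| v as f32).collect();
    let mut x: Vec<f64> = solve_f32(&l, &b32).iter().map(|&v| v as f64).collect();
    for step in 0..=MAX_STEPS {
        // Residual r = b - A x, accumulated in f64.
        let r: Vec<f64> = a
            .iter()
            .zip(b)
            .map(|(row, bi)| bi - row.iter().zip(&x).map(|(aij, xj)| aij * xj).sum::<f64>())
            .collect();
        let rel = norm(&r) / norm(b);
        if rel <= TOL {
            return Some((x, step, rel));
        }
        if step < MAX_STEPS {
            // Correction solve reuses the f32 factor; the update is applied in f64.
            let r32: Vec<f32> = r.iter().map(|&v| v as f32).collect();
            let d = solve_f32(&l, &r32);
            for (xi, di) in x.iter_mut().zip(&d) {
                *xi += *di as f64;
            }
        }
    }
    None
}

fn main() {
    let a = vec![vec![4.0, 1.0, 0.0], vec![1.0, 3.0, 1.0], vec![0.0, 1.0, 2.0]];
    match solve_refined(&a, &[1.0, 2.0, 3.0]) {
        Some((x, steps, rel)) => println!("x = {x:?}, steps = {steps}, rel = {rel:e}"),
        None => println!("fall back to fp64 factorization"),
    }
}
```

The sketch checks only the tolerance and the step cap; the monotonicity check on the residual that the documented predicate mentions is left out for brevity.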

Source

pub const MATVEC_OFFLOAD_FLOPS_MIN: u128 = 10_000_000

Minimum total CG-amortised matvec flops for offloading to the device. Below this floor the host↔device transfer of the row frames + CG vectors is not repaid by the device matvec, so the reduced-Schur PCG hot loop stays on the CPU.

The dense-Direct path keys on dense_reduction_flops_min (a single big factorization). The matrix-free SAE matvec is different: no single apply trips that floor (each is a stack of n tiny d×d solves + sparse m·k gather/scatter), but the whole CG solve runs the apply O(cg_iters) times over the same resident frames. The device wins when the summed matvec work over the solve exceeds the one-time staging cost — so the gate keys on cg_iters · per_apply_flops, not one apply.

Set one order of magnitude below the dense floor: the matvec frames stay resident across CG iterations (uploaded once), so the per-flop transfer amortization is 1/cg_iters of a cold dense launch, and the breakeven drops accordingly.

Source

pub const MATVEC_OFFLOAD_MIN_CG_ITERS: usize = 8

Conservative seed for the reduced-Schur PCG iteration count when the caller cannot supply a measured budget. InexactPCG on an SAE β-block of width k converges in O(√κ) iterations; this floor keeps the work estimate honest (≥ this many applies) without over-claiming a tight solve. Used only to amortise the staging cost in the work estimate.

Source

pub const fn iterative_refinement_should_attempt(&self, p: usize) -> bool

Return true when the policy and problem size together suggest that attempting fp32 factorization + iterative refinement will be profitable.

The predicate is conservative:

  • GpuMixedPrecisionPolicy::Off or Never → always false.
  • Refinement with p < REFINEMENT_MIN_P → false (GEMV overhead not amortised by fp32 POTRF savings below this threshold).
  • Otherwise true; the caller still falls back to fp64 factorization when the runtime fp32 POTRF fails or when the measured residual is non-monotone.
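The three rules above amount to a small match. The following sketch assumes a stand-in enum with the variant names that appear in the list (Off, Never, Refinement); the real GpuMixedPrecisionPolicy may carry other variants or data, and the free function here is hypothetical rather than the documented method.

```rust
/// Hypothetical stand-in for `GpuMixedPrecisionPolicy`. The variant names are
/// taken from the documentation text; the real enum may differ.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum MixedPrecision {
    Off,
    Never,
    Refinement,
}

/// Mirrors the documented REFINEMENT_MIN_P constant.
const REFINEMENT_MIN_P: usize = 64;

/// Sketch of the predicate: disabled modes never attempt refinement, and the
/// refinement mode attempts it only from the minimum problem dimension up.
const fn should_attempt_refinement(mode: MixedPrecision, p: usize) -> bool {
    match mode {
        MixedPrecision::Off | MixedPrecision::Never => false,
        MixedPrecision::Refinement => p >= REFINEMENT_MIN_P,
    }
}

fn main() {
    for p in [32, 64, 512] {
        println!(
            "p = {p}: attempt = {}",
            should_attempt_refinement(MixedPrecision::Refinement, p)
        );
    }
}
```

A true result is only a suggestion to try the fp32 path; the runtime fallback described in the last bullet still applies.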
Source

pub const fn dense_gemv_target_is_gpu( &self, n: usize, p: usize, resident: bool, ) -> bool

Source

pub const fn xtwx_target_is_gpu( &self, n: usize, p: usize, materialized: bool, ) -> bool

Source

pub const fn xtwy_target_is_gpu( &self, n: usize, px: usize, q: usize, materialized: bool, ) -> bool

Source

pub const fn potrf_target_is_gpu(&self, p: usize, h_resident: bool) -> bool

Source

pub const fn dense_hessian_work_target_is_gpu(&self, n: usize, p: usize) -> bool

Source

pub const fn reduced_schur_matvec_should_offload( &self, n: usize, k: usize, d: usize, cg_iters: usize, ) -> bool

Work-based admission for offloading the reduced-Schur PCG matvec (the InexactPCG hot loop for matrix-free SAE β-blocks) to the device.

This is the Phase-1 (#1017) re-keying: the dense gates key on row count (xtwx_n_min, row_kernel_min_n at 50k) or a single big-factorization flop floor, neither of which the SAE LLM shape trips — (n≈2000) × (k≈2048) × (d≈8) is thousands of small dense ops, no single op large, so the row-count gate keeps the whole fit on one CPU core. Here the gate is the total batched work over the CG solve:

estimated_device_flops = cg_iters · per_apply_flops(n, k, d)
should_offload = estimated_device_flops ≥ T_breakeven

where T_breakeven = MATVEC_OFFLOAD_FLOPS_MIN accounts for the host↔device staging of the row frames + CG vectors amortised over the cg_iters applies that reuse the resident frames (so the per-flop transfer cost is 1/cg_iters of a cold launch, an order of magnitude below the dense-Direct floor).

Pure function of the shape: no device needed to evaluate, so it is unit-testable. The caller still falls back to the bit-identical CPU matvec whenever the backend build declines, so admitting a shape never changes the numerics — only where the Σ_i Y_iᵀ(Y_i x) flops execute.

  • n — number of row blocks (SAE observations / latent rows).
  • k — border β width (the SAE decoder atom count K).
  • d — per-row latent / active-frame depth (the M dimension).
  • cg_iters — expected PCG iteration budget; the per-apply work is multiplied by this because the frames stay resident across iterations. Pass Self::MATVEC_OFFLOAD_MIN_CG_ITERS when no measured budget is available; a tighter (smaller) value only makes the gate stricter.
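The work-based gate can be sketched as a pure function of the shape. The documentation does not spell out per_apply_flops(n, k, d), so the cost model below (4·n·d·k flops per apply, i.e. two d × k products per row block at 2 flops per multiply-add) is an assumption made for illustration, as are the free-function names.

```rust
/// Mirrors the documented constants.
const MATVEC_OFFLOAD_FLOPS_MIN: u128 = 10_000_000;
const MATVEC_OFFLOAD_MIN_CG_ITERS: usize = 8;

/// ASSUMED cost model, not taken from the crate: each of the n row blocks
/// applies Y_i (d x k) and then its transpose, at 2 flops per multiply-add.
const fn per_apply_flops(n: usize, k: usize, d: usize) -> u128 {
    4u128
        .saturating_mul(n as u128)
        .saturating_mul(d as u128)
        .saturating_mul(k as u128)
}

/// Sketch of the gate: total work over the CG solve against the break-even floor.
const fn should_offload(n: usize, k: usize, d: usize, cg_iters: usize) -> bool {
    let total = (cg_iters as u128).saturating_mul(per_apply_flops(n, k, d));
    total >= MATVEC_OFFLOAD_FLOPS_MIN
}

fn main() {
    // The SAE LLM shape quoted in the text: n ≈ 2000, k ≈ 2048, d ≈ 8.
    let admitted = should_offload(2000, 2048, 8, MATVEC_OFFLOAD_MIN_CG_ITERS);
    println!("SAE shape admitted: {admitted}");
}
```

Under this assumed model the SAE shape clears the floor by about two orders of magnitude, while a smaller cg_iters shrinks the estimate and can only turn an admission into a rejection.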
§Live arrow-Schur call site

crate::solver::arrow_schur::maybe_inject_gpu_schur_matvec gates the InexactPCG reduced-Schur matvec injection on this predicate: reduced_schur_matvec_should_offload(sys.rows.len(), sys.k, sys.d, options.pcg.max_iterations.min(options.trust_region.max_iterations)), where sys.d is the system’s max per-row latent depth and the iteration budget is the same max_iterations the PCG loop launches with. try_device_arrow_direct (the dense Direct point solve) correctly keeps dense_hessian_work_target_is_gpu: that path is a single large factorization, not the amortised matvec.

Source§

impl GpuDispatchPolicy

Source

pub const DEVICE_LOOP_MIN_P: usize = 32

Minimum design column count for the device-resident inner/outer loops.

Below this width the per-iteration XᵀWX + Cholesky is dominated by launch latency and PCIe staging rather than arithmetic, so the host LM loop (which populates the full PirlsResult surface as a free side-effect) is strictly cheaper. Shared by both the inner PIRLS and outer REML admission predicates so they cannot drift apart.

Source

pub const fn should_use_gpu_pirls_loop(&self, adm: PirlsLoopAdmission) -> bool

Conservative admission predicate for routing fit_model_for_fixed_rho_with_adaptive_kkt through the Stage 3.3 device-resident PIRLS loop instead of the CPU LM loop.

The threshold is the dense XᵀWX work estimate, not row count alone: LLM/SAE fits can have only a few thousand rows but thousands of columns, so 2*n*p^2 already dwarfs launch/staging overhead. Smaller fits stay on the CPU LM loop where the full PirlsResult surface (firth, EDF, per-row weights, …) is already populated as a free side-effect of the iteration.
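A rough sketch of a work-based admission of this kind follows. The fields of PirlsLoopAdmission and the actual flop floor are not shown on this page, so the Admission struct, its two fields, and the min_flops parameter are all hypothetical; only the 2·n·p² estimate and DEVICE_LOOP_MIN_P come from the text.

```rust
/// Mirrors the documented DEVICE_LOOP_MIN_P constant.
const DEVICE_LOOP_MIN_P: usize = 32;

/// Hypothetical stand-in for `PirlsLoopAdmission`; the real type is not
/// shown here and may carry more fields.
#[derive(Clone, Copy, Debug)]
struct Admission {
    n: usize, // rows of the design matrix
    p: usize, // columns of the design matrix
}

/// Dense XᵀWX work estimate from the text: 2 * n * p^2 flops.
const fn dense_xtwx_flops(n: usize, p: usize) -> u128 {
    2u128
        .saturating_mul(n as u128)
        .saturating_mul(p as u128)
        .saturating_mul(p as u128)
}

/// Sketch of the predicate. `min_flops` is a hypothetical parameter standing
/// in for whatever floor the policy actually uses.
const fn admit_pirls_loop(adm: Admission, min_flops: u128) -> bool {
    adm.p >= DEVICE_LOOP_MIN_P && dense_xtwx_flops(adm.n, adm.p) >= min_flops
}

fn main() {
    let sae = Admission { n: 2000, p: 2048 };
    let tall = Admission { n: 1_000_000, p: 16 };
    println!("few rows, many columns: {}", admit_pirls_loop(sae, 1_000_000_000));
    println!("many rows, few columns: {}", admit_pirls_loop(tall, 1_000_000_000));
}
```

The two shapes in main illustrate the point of keying on work rather than rows: the 2000-row fit is admitted, while the million-row fit with 16 columns is rejected by the width floor.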

Source

pub const fn should_run_reml_outer_on_device( &self, adm: RemlOuterAdmission, ) -> bool

Admission predicate for routing the outer REML BFGS-over-ρ loop onto a device-resident driver that keeps the BFGS state (ρ, gradient, Hessian approx) on-device and only downloads the per-step scalar metrics (objective value, gradient norm, convergence flag).

The dense-work threshold piggybacks on the existing inner-PIRLS admission predicate because the device-resident outer loop calls pirls_loop_on_stream per step and must not pay the host hop for small fits the inner loop would have rejected anyway. The num_rho ≥ 2 floor rules out the trivial single-smoother case where host orchestration is already negligible and the device BFGS state (one length-num_rho gradient + a num_rho × num_rho Hessian approx) collapses to a couple of scalars not worth keeping on device.
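The composition described here can be sketched in a few lines. Both functions are hypothetical: the real predicate takes a RemlOuterAdmission whose fields are not shown on this page, so the inner-loop decision is passed in as a plain bool.

```rust
/// Sketch of the outer admission: the inner PIRLS loop must already be
/// admitted, and there must be at least two smoothing parameters.
const fn should_run_outer_on_device(inner_admitted: bool, num_rho: usize) -> bool {
    inner_admitted && num_rho >= 2
}

/// Number of scalars in the device BFGS state described in the text:
/// one length-num_rho gradient plus a num_rho x num_rho Hessian approximation.
const fn bfgs_state_len(num_rho: usize) -> usize {
    num_rho + num_rho * num_rho
}

fn main() {
    for num_rho in [1, 2, 8] {
        println!(
            "num_rho = {num_rho}: state = {} scalars, on device = {}",
            bfgs_state_len(num_rho),
            should_run_outer_on_device(true, num_rho)
        );
    }
}
```

For num_rho = 1 the state is two scalars, which is the "couple of scalars" case the floor rules out.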

Trait Implementations§

Source§

impl Clone for GpuDispatchPolicy

Source§

fn clone(&self) -> GpuDispatchPolicy

Returns a duplicate of the value.
1.0.0 (const: unstable) · Source§

fn clone_from(&mut self, source: &Self)

Performs copy-assignment from source.
Source§

impl Debug for GpuDispatchPolicy

Source§

fn fmt(&self, f: &mut Formatter<'_>) -> Result

Formats the value using the given formatter.
Source§

impl Default for GpuDispatchPolicy

Source§

fn default() -> Self

Conservative seed thresholds used before device calibration and when calibration cannot run on the current host.

The production runtime replaces these with [crate::calibration::calibrated_policy_for_device] after the CUDA probe selects a concrete device. Keep these values conservative: they are the typed baseline for CPU-only builds, failed calibration, and unit tests that exercise policy predicates without initializing CUDA.
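The fallback pattern described above can be sketched with Option::unwrap_or_default. The two-field Policy struct is a hypothetical miniature of GpuDispatchPolicy, its default values are the ones quoted elsewhere on this page (potrf_min_p = 512, row_kernel_min_n at 50k), and modelling a failed calibration as None is an assumption; the signature of calibrated_policy_for_device is not shown here.

```rust
/// Hypothetical two-field miniature of `GpuDispatchPolicy`.
#[derive(Clone, Debug, PartialEq, Eq)]
struct Policy {
    potrf_min_p: usize,
    row_kernel_min_n: usize,
}

impl Default for Policy {
    /// Conservative seed thresholds; the values are those quoted in the text.
    fn default() -> Self {
        Policy {
            potrf_min_p: 512,
            row_kernel_min_n: 50_000,
        }
    }
}

/// Use the calibrated policy when calibration produced one, otherwise fall
/// back to the conservative defaults (CPU-only build, failed probe, tests).
fn effective_policy(calibrated: Option<Policy>) -> Policy {
    calibrated.unwrap_or_default()
}

fn main() {
    let tuned = Policy {
        potrf_min_p: 256,
        ..Policy::default()
    };
    println!("calibrated: {:?}", effective_policy(Some(tuned)));
    println!("fallback:   {:?}", effective_policy(None));
}
```

Because the predicates are pure functions of the policy and the shape, unit tests can run against Policy::default() without initializing CUDA.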

Source§

impl<'de> Deserialize<'de> for GpuDispatchPolicy

Source§

fn deserialize<__D>(__deserializer: __D) -> Result<Self, __D::Error>
where __D: Deserializer<'de>,

Deserialize this value from the given Serde deserializer.
Source§

impl Eq for GpuDispatchPolicy

Source§

impl PartialEq for GpuDispatchPolicy

Source§

fn eq(&self, other: &GpuDispatchPolicy) -> bool

Tests for self and other values to be equal, and is used by ==.
1.0.0 (const: unstable) · Source§

fn ne(&self, other: &Rhs) -> bool

Tests for !=. The default implementation is almost always sufficient, and should not be overridden without very good reason.
Source§

impl Serialize for GpuDispatchPolicy

Source§

fn serialize<__S>(&self, __serializer: __S) -> Result<__S::Ok, __S::Error>
where __S: Serializer,

Serialize this value into the given Serde serializer.
Source§

impl StructuralPartialEq for GpuDispatchPolicy

Auto Trait Implementations§

Blanket Implementations§

Source§

impl<T> Any for T
where T: 'static + ?Sized,

Source§

fn type_id(&self) -> TypeId

Gets the TypeId of self.
Source§

impl<T> Borrow<T> for T
where T: ?Sized,

Source§

fn borrow(&self) -> &T

Immutably borrows from an owned value.
Source§

impl<T> BorrowMut<T> for T
where T: ?Sized,

Source§

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.
Source§

impl<T> ByRef<T> for T

Source§

fn by_ref(&self) -> &T

Source§

impl<ST, DT> CastableFrom<ST, Initialized, Initialized> for DT
where ST: ?Sized, DT: ?Sized,

Source§

impl<ST, DT> CastableFrom<ST, Uninit, Uninit> for DT
where ST: ?Sized, DT: ?Sized,

Source§

impl<T> CloneToUninit for T
where T: Clone,

Source§

unsafe fn clone_to_uninit(&self, dest: *mut u8)

🔬This is a nightly-only experimental API. (clone_to_uninit)
Performs copy-assignment from self to dest.
Source§

impl<T> DeserializeOwned for T
where T: for<'de> Deserialize<'de>,

Source§

impl<T> DistributionExt for T
where T: ?Sized,

Source§

fn rand<T>(&self, rng: &mut (impl Rng + ?Sized)) -> T
where Self: Distribution<T>,

Source§

impl<T> From<T> for T

Source§

fn from(t: T) -> T

Returns the argument unchanged.

Source§

impl<T, U> Imply<T> for U
where T: ?Sized, U: ?Sized,

Source§

impl<T, U> Into<U> for T
where U: From<T>,

Source§

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

Source§

impl<T> IntoEither for T

Source§

fn into_either(self, into_left: bool) -> Either<Self, Self>

Converts self into a Left variant of Either<Self, Self> if into_left is true. Converts self into a Right variant of Either<Self, Self> otherwise.
Source§

fn into_either_with<F>(self, into_left: F) -> Either<Self, Self>
where F: FnOnce(&Self) -> bool,

Converts self into a Left variant of Either<Self, Self> if into_left(&self) returns true. Converts self into a Right variant of Either<Self, Self> otherwise.
Source§

impl<T> Pointable for T

Source§

const ALIGN: usize

The alignment of pointer.
Source§

type Init = T

The type for initializers.
Source§

unsafe fn init(init: <T as Pointable>::Init) -> usize

Initializes a with the given initializer.
Source§

unsafe fn deref<'a>(ptr: usize) -> &'a T

Dereferences the given pointer.
Source§

unsafe fn deref_mut<'a>(ptr: usize) -> &'a mut T

Mutably dereferences the given pointer.
Source§

unsafe fn drop(ptr: usize)

Drops the object pointed to by the given pointer.
Source§

impl<T> Read<Exclusive, BecauseExclusive> for T
where T: ?Sized,

Source§

impl<T> Same for T

Source§

type Output = T

Should always be Self
Source§

impl<T> ToOwned for T
where T: Clone,

Source§

type Owned = T

The resulting type after obtaining ownership.
Source§

fn to_owned(&self) -> T

Creates owned data from borrowed data, usually by cloning.
Source§

fn clone_into(&self, target: &mut T)

Uses borrowed data to replace owned data, usually by cloning.
Source§

impl<T, U> TryFrom<U> for T
where U: Into<T>,

Source§

type Error = Infallible

The type returned in the event of a conversion error.
Source§

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.
Source§

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

Source§

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.
Source§

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.
Source§

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

Source§

fn vzip(self) -> V