// gam_gpu/policy.rs

use serde::{Deserialize, Serialize};

3#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
4pub enum GpuMixedPrecisionPolicy {
5    /// Always use fp64 factorization; no refinement attempted.
6    Off,
7    /// Attempt fp32 Cholesky factorization followed by up to
8    /// `REFINEMENT_MAX_STEPS` fp64-residual refinement steps. Policy admits
9    /// the attempt only when `p ≥ REFINEMENT_MIN_P` (so that the fp64 GEMV
10    /// overhead is amortized) and the measured residual drops monotonically.
11    /// Falls back to fp64 factorization automatically when the residual does
12    /// not decrease (κ(A)·u ≥ 1 regime) or when the fp32 POTRF itself fails.
13    Refinement,
14    /// Always use fp64 factorization; equivalent to `Off` but signals that
15    /// an explicit policy decision was taken.
16    Never,
17}
18
19#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
20pub struct GpuDispatchPolicy {
21    pub xtwx_n_min: usize,
22    pub xtwx_flops_min: usize,
23    pub xtwx_use_fused_below_p: usize,
24    pub gemm_min_flops: usize,
25    pub potrf_min_p: usize,
26    pub small_dense_batched_potrf_max_p: usize,
27    pub small_dense_batched_potrf_min_batch: usize,
28    pub syevd_min_p: usize,
29    pub sparse_min_nnz: usize,
30    pub fused_kernel_min_n: usize,
31    pub keep_design_resident_min_bytes: usize,
32    pub prefer_gpu_factorization_min_p: usize,
33    pub row_kernel_min_n: usize,
34    pub mixed_precision: GpuMixedPrecisionPolicy,
35}
36
37impl Default for GpuDispatchPolicy {
38    /// Conservative seed thresholds used before device calibration and when
39    /// calibration cannot run on the current host.
40    ///
41    /// The production runtime replaces these with
42    /// [`crate::calibration::calibrated_policy_for_device`] after the CUDA
43    /// probe selects a concrete device. Keep these values conservative: they
44    /// are the typed baseline for CPU-only builds, failed calibration, and unit
45    /// tests that exercise policy predicates without initializing CUDA.
46    fn default() -> Self {
47        Self {
48            xtwx_n_min: 50_000,
49            xtwx_flops_min: 100_000_000,
50            xtwx_use_fused_below_p: 256,
51            gemm_min_flops: 100_000_000,
52            potrf_min_p: 512,
53            small_dense_batched_potrf_max_p: 32,
54            small_dense_batched_potrf_min_batch: 8,
55            syevd_min_p: 256,
56            sparse_min_nnz: 1_000_000,
57            fused_kernel_min_n: 100_000,
58            keep_design_resident_min_bytes: 32 * 1024 * 1024,
59            prefer_gpu_factorization_min_p: 512,
60            row_kernel_min_n: 50_000,
61            mixed_precision: GpuMixedPrecisionPolicy::Refinement,
62        }
63    }
64}
65
66impl GpuDispatchPolicy {
67    /// Minimum problem dimension for the fp32+refinement path.
68    ///
69    /// Below this threshold the fp64 GEMV needed for the residual check costs
70    /// more than the savings from fp32 factorization. The threshold is set so
71    /// that a single `p × p` DGEMV (2p² flops) is at least 10× cheaper than
72    /// the `p³/3` POTRF (i.e. p ≥ 64) while still leaving margin for the
73    /// POTRF/POTRS launches. In practice `p ≥ 64` matches the existing
74    /// `potrf_min_p = 512` floor for GPU dispatch, so the refinement path only
75    /// activates when the GPU factorization path is already chosen.
76    pub const REFINEMENT_MIN_P: usize = 64;
77
78    /// Maximum number of fp32-correction steps per solve.
79    ///
80    /// Two steps suffice for κ(A) ≤ 10⁵ at fp32 (u ≈ 6 × 10⁻⁸): after step
81    /// 1 the error is O(κ u)² ≈ 10⁻⁶, after step 2 it is O(κ u)⁴ ≈ 10⁻¹²,
82    /// which is well within the fp64 unit roundoff of 10⁻¹⁶ × κ. A cap of 3
83    /// is used defensively.
84    pub const REFINEMENT_MAX_STEPS: usize = 3;
85
86    /// Relative residual tolerance for declaring convergence.
87    ///
88    /// `‖r‖ / ‖b‖ ≤ tol` is considered a converged solve. 10⁻¹² is two
89    /// orders of magnitude above the fp64 machine epsilon times a moderate
90    /// condition number, leaving the policy conservative.
91    pub const REFINEMENT_TOL: f64 = 1e-12;
92
93    /// Return `true` when the policy and problem size together suggest that
94    /// attempting fp32 factorization + iterative refinement will be profitable.
95    ///
96    /// The predicate is conservative:
97    ///   * `GpuMixedPrecisionPolicy::Off` or `Never` → always `false`.
98    ///   * `Refinement` with `p < REFINEMENT_MIN_P` → `false` (GEMV overhead
99    ///     not amortised by fp32 POTRF savings below this threshold).
100    ///   * Otherwise `true`; the caller still falls back to fp64 factorization
101    ///     when the runtime fp32 POTRF fails or when the measured residual is
102    ///     non-monotone.
103    #[inline]
104    pub const fn iterative_refinement_should_attempt(&self, p: usize) -> bool {
105        match self.mixed_precision {
106            GpuMixedPrecisionPolicy::Off | GpuMixedPrecisionPolicy::Never => false,
107            GpuMixedPrecisionPolicy::Refinement => p >= Self::REFINEMENT_MIN_P,
108        }
109    }
110
111    pub const fn dense_gemv_target_is_gpu(&self, n: usize, p: usize, resident: bool) -> bool {
112        resident || n.saturating_mul(p).saturating_mul(2) >= self.gemm_min_flops
113    }
114
115    pub const fn xtwx_target_is_gpu(&self, n: usize, p: usize, materialized: bool) -> bool {
116        materialized && n > 0 && p > 0 && self.xtwx_flops(n, p) >= self.dense_reduction_flops_min()
117    }
118
119    pub const fn xtwy_target_is_gpu(
120        &self,
121        n: usize,
122        px: usize,
123        q: usize,
124        materialized: bool,
125    ) -> bool {
126        materialized
127            && n > 0
128            && px > 0
129            && q > 0
130            && self.xtwy_flops(n, px, q) >= self.dense_reduction_flops_min()
131    }
132
133    pub const fn potrf_target_is_gpu(&self, p: usize, h_resident: bool) -> bool {
134        h_resident && p >= self.potrf_min_p
135    }
136
137    pub const fn dense_hessian_work_target_is_gpu(&self, n: usize, p: usize) -> bool {
138        n > 0
139            && p >= Self::DEVICE_LOOP_MIN_P
140            && self.xtwx_flops(n, p) >= self.dense_reduction_flops_min()
141    }
142
143    const fn dense_reduction_flops_min(&self) -> u128 {
144        if self.xtwx_flops_min < self.gemm_min_flops {
145            self.xtwx_flops_min as u128
146        } else {
147            self.gemm_min_flops as u128
148        }
149    }
150
151    const fn xtwx_flops(&self, n: usize, p: usize) -> u128 {
152        2u128 * (n as u128) * (p as u128) * (p as u128)
153    }
154
155    const fn xtwy_flops(&self, n: usize, px: usize, q: usize) -> u128 {
156        2u128 * (n as u128) * (px as u128) * (q as u128)
157    }
158
159    /// Minimum total CG-amortised matvec flops below which the host↔device
160    /// transfer of the row frames + CG vectors is not repaid by the device
161    /// matvec, so the reduced-Schur PCG hot loop stays on the CPU.
162    ///
163    /// The dense-Direct path keys on `dense_reduction_flops_min` (a single big
164    /// factorization). The matrix-free SAE matvec is different: no single apply
165    /// trips that floor (each is a stack of `n` tiny `d×d` solves + sparse
166    /// `m·k` gather/scatter), but the *whole CG solve* runs the apply
167    /// `O(cg_iters)` times over the same resident frames. The device wins when
168    /// the **summed** matvec work over the solve exceeds the one-time staging
169    /// cost — so the gate keys on `cg_iters · per_apply_flops`, not one apply.
170    ///
171    /// Set one order of magnitude below the dense floor: the matvec frames stay
172    /// resident across CG iterations (uploaded once), so the per-flop transfer
173    /// amortization is `1/cg_iters` of a cold dense launch, and the breakeven
174    /// drops accordingly.
175    pub const MATVEC_OFFLOAD_FLOPS_MIN: u128 = 10_000_000;
176
177    /// Conservative seed for the reduced-Schur PCG iteration count when the
178    /// caller cannot supply a measured budget. InexactPCG on an SAE β-block of
179    /// width `k` converges in `O(√κ)` iterations; this floor keeps the work
180    /// estimate honest (≥ this many applies) without over-claiming a tight
181    /// solve. Used only to amortise the staging cost in the work estimate.
182    pub const MATVEC_OFFLOAD_MIN_CG_ITERS: usize = 8;
183
184    /// Per-apply flop estimate for one reduced-Schur matvec `S·x` of a
185    /// matrix-free SAE Kronecker system, as a pure function of the system shape.
186    ///
187    /// Per row block `i` the apply does: a forward cross-block GEMV
188    /// `v_i = H_tβ^(i)·x` (`≈ 2·d·k` multiply-adds, with the per-row latent
189    /// depth `d` as the M-frame width and `k` the border), a `d×d` triangular
190    /// solve through the cached Cholesky factor (`≈ d²`), and a transpose
191    /// cross-block GEMV `H_βt^(i)·w_i` (`≈ 2·d·k`). The two `2·d·k` GEMVs would
192    /// sum to `4·d·k`; this estimate deliberately undercounts to a single
193    /// `2·d·k` cross term as a conservative (lower-bound) admission floor, so
194    /// the apply is modelled as `≈ n·(2·d·k + d²)`. This is a deliberate
195    /// lower bound on the true `≈ n·(4·d·k + d²)` arithmetic — admitting a
196    /// shape under the smaller figure can only be more conservative, never
197    /// over-eager. It is keyed on the *frame depth* `d` (M) and border width
198    /// `k` (p), not row count alone, so LLM shapes (few rows, wide `k`, modest
199    /// `d`) register arithmetic the row-count gate misses.
200    ///
201    /// USE FOR DISPATCH GATING ONLY. This is **not** a flop count: it omits the
202    /// transpose cross-block GEMV (`2·d·k`), so it is a strict lower bound on the
203    /// true per-apply work `n·(4·d·k + d²)`. The gate can therefore only
204    /// under-admit, never over-admit. Do not reuse it for benchmark / speedup
205    /// accounting.
206    const fn admission_work_lower_bound(n: usize, k: usize, d: usize) -> u128 {
207        let n = n as u128;
208        let k = k as u128;
209        let d = d as u128;
210        // 2·d·k cross-block apply (forward only) + d² per-row solve — the
211        // transpose GEMV is intentionally dropped so this stays a lower bound.
212        n.saturating_mul(
213            2u128
214                .saturating_mul(d)
215                .saturating_mul(k)
216                .saturating_add(d * d),
217        )
218    }
219
220    /// Work-based admission for offloading the **reduced-Schur PCG matvec**
221    /// (the InexactPCG hot loop for matrix-free SAE β-blocks) to the device.
222    ///
223    /// This is the Phase-1 (#1017) re-keying: the dense gates key on row count
224    /// (`xtwx_n_min`, `row_kernel_min_n` at 50k) or a single big-factorization
225    /// flop floor, neither of which the SAE LLM shape trips — `(n≈2000) ×
226    /// (k≈2048) × (d≈8)` is *thousands of small dense ops*, no single op large,
227    /// so the row-count gate keeps the whole fit on one CPU core. Here the gate
228    /// is the **total batched work over the CG solve**:
229    ///
230    /// ```text
231    /// estimated_device_flops = cg_iters · per_apply_flops(n, k, d)
232    /// should_offload = estimated_device_flops ≥ T_breakeven
233    /// ```
234    ///
235    /// where `T_breakeven = MATVEC_OFFLOAD_FLOPS_MIN` accounts for the
236    /// host↔device staging of the row frames + CG vectors amortised over the
237    /// `cg_iters` applies that reuse the resident frames (so the per-flop
238    /// transfer cost is `1/cg_iters` of a cold launch, an order of magnitude
239    /// below the dense-Direct floor).
240    ///
241    /// Pure function of the shape: no device needed to evaluate, so it is unit-
242    /// testable. The caller still falls back to the bit-identical CPU matvec
243    /// whenever the backend build declines, so admitting a shape never changes
244    /// the numerics — only where the `Σ_i Y_iᵀ(Y_i x)` flops execute.
245    ///
246    /// * `n`        — number of row blocks (SAE observations / latent rows).
247    /// * `k`        — border β width (the SAE decoder atom count `K`).
248    /// * `d`        — per-row latent / active-frame depth (the M dimension).
249    /// * `cg_iters` — expected PCG iteration budget; the per-apply work is
250    ///   multiplied by this because the frames stay resident across iterations.
251    ///   Pass [`Self::MATVEC_OFFLOAD_MIN_CG_ITERS`] when no measured budget is
252    ///   available; a tighter (smaller) value only makes the gate stricter.
253    ///
254    /// ## Live arrow-Schur call site
255    ///
256    /// `crate::solver::arrow_schur::maybe_inject_gpu_schur_matvec` gates the
257    /// InexactPCG reduced-Schur matvec injection on this predicate:
258    /// `reduced_schur_matvec_should_offload(sys.rows.len(), sys.k, sys.d,
259    /// options.pcg.max_iterations.min(options.trust_region.max_iterations))`,
260    /// where `sys.d` is the system's max per-row latent depth and the iteration
261    /// budget is the same `max_iterations` the PCG loop launches with.
262    /// `try_device_arrow_direct` (the **dense** Direct point solve) correctly
263    /// keeps `dense_hessian_work_target_is_gpu`: that path is a single large
264    /// factorization, not the amortised matvec.
265    pub const fn reduced_schur_matvec_should_offload(
266        &self,
267        n: usize,
268        k: usize,
269        d: usize,
270        cg_iters: usize,
271    ) -> bool {
272        if n == 0 || k == 0 || d == 0 || cg_iters == 0 {
273            return false;
274        }
275        // The border width must clear the device-loop floor: below it the per-
276        // apply launch latency (one kernel sequence per matvec) dominates any
277        // arithmetic regardless of how many CG iterations run.
278        if k < Self::DEVICE_LOOP_MIN_P {
279            return false;
280        }
281        let per_apply = Self::admission_work_lower_bound(n, k, d);
282        let total = per_apply.saturating_mul(cg_iters as u128);
283        total >= Self::MATVEC_OFFLOAD_FLOPS_MIN
284    }
285}
286
/// The aspirational single-GPU design-row throughput the #1412 decision gate is
/// supposed to establish for the LLM-shape batched-Cholesky + tile-GEMM fit
/// pipeline: 100 000 design rows processed per wall-clock second per device.
///
/// The original gate *claimed* this number without ever measuring it. The
/// honest contract runs the other way. A benchmark
/// (`examples/throughput_1412.rs`) measures the true rows/sec on a real device.
/// [`GpuThroughputVerdict::from_measurement`] then reports whether the measured
/// value meets the target. The verdict is a *function of the measurement*, not
/// a hardcoded assertion. See `tests/owed_1412.rs`.
pub const GPU_THROUGHPUT_TARGET_ROWS_PER_SEC: f64 = 100_000.0;

/// Outcome of comparing a *measured* GPU throughput against a target.
///
/// Verdicts are built by [`Self::from_measurement`] or
/// [`Self::from_measurement_against`]. A verdict therefore never asserts a
/// target that was not actually established by a measurement.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GpuThroughputVerdict {
    /// The measured design-rows-per-second on the device under test, stored
    /// as given even when it is not a usable measurement.
    pub measured_rows_per_sec: f64,
    /// The target the measurement is compared against.
    pub target_rows_per_sec: f64,
    /// `measured / target` when both are usable (finite and positive), else
    /// `0.0`. A value ≥ 1.0 means the target was established.
    pub fraction_of_target: f64,
    /// True iff both the measurement and the target are usable (finite and
    /// positive) and `measured_rows_per_sec >= target_rows_per_sec`.
    pub meets_target: bool,
}

impl GpuThroughputVerdict {
    /// Build a verdict from a measured throughput against
    /// [`GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`].
    ///
    /// A non-finite or non-positive measurement can never meet the target,
    /// because it is not a usable measurement.
    #[inline]
    pub fn from_measurement(measured_rows_per_sec: f64) -> Self {
        Self::from_measurement_against(measured_rows_per_sec, GPU_THROUGHPUT_TARGET_ROWS_PER_SEC)
    }

    /// Build a verdict against an explicit target. Tests use this to probe the
    /// comparison logic without depending on the global target constant.
    ///
    /// A value is *usable* when it is finite and strictly positive. The ratio
    /// is computed and the comparison made only when both the measurement and
    /// the target are usable. Otherwise the verdict reports
    /// `fraction_of_target == 0.0` and `meets_target == false`.
    ///
    /// Requiring a usable target keeps the two fields consistent. Without the
    /// check, a zero or negative target would report `meets_target == true`
    /// alongside `fraction_of_target == 0.0`.
    #[inline]
    pub fn from_measurement_against(measured_rows_per_sec: f64, target_rows_per_sec: f64) -> Self {
        let is_usable = |rate: f64| rate.is_finite() && rate > 0.0;
        let comparable = is_usable(measured_rows_per_sec) && is_usable(target_rows_per_sec);
        let fraction_of_target = if comparable {
            measured_rows_per_sec / target_rows_per_sec
        } else {
            0.0
        };
        Self {
            measured_rows_per_sec,
            target_rows_per_sec,
            fraction_of_target,
            meets_target: comparable && measured_rows_per_sec >= target_rows_per_sec,
        }
    }
}

/// `(response, link)` families that the Stage 3.3 device-resident PIRLS loop
/// evaluates directly, without going through the Level-B raw-body NVRTC path.
///
/// This is a policy-layer mirror of `PirlsRowFamily::ALL`. Keeping a plain
/// copy here lets the CPU PIRLS entry link against the admission predicates
/// without dragging a Linux-only enum into every host compilation unit.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PirlsLoopFamilyKind {
    /// Bernoulli response, logit link.
    BernoulliLogit,
    /// Bernoulli response, probit link.
    BernoulliProbit,
    /// Bernoulli response, complementary log-log link.
    BernoulliCLogLog,
    /// Poisson response, log link.
    PoissonLog,
    /// Gaussian response, identity link.
    GaussianIdentity,
    /// Gamma response, log link.
    GammaLog,
}

/// Curvature surface used by the device-resident PIRLS loop. The GPU loop has
/// kernels for exactly these two surfaces.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PirlsLoopCurvatureKind {
    /// Fisher (expected-information) curvature.
    Fisher,
    /// Observed-information curvature.
    Observed,
}

364/// Inputs to [`should_run_reml_outer_on_device`]. The admission predicate
365/// for routing the *outer* REML BFGS-over-ρ loop onto a fully device-resident
366/// driver (rather than the host orchestrator that hops out per step).
367///
368/// Fields are intentionally lifted from data the CPU REML entry has on hand
369/// before it touches the seed generator or the inner P-IRLS loop, so the
370/// admission check is allocation-free and can short-circuit before any
371/// device call.
372#[derive(Clone, Copy, Debug)]
373pub struct RemlOuterAdmission {
374    /// Active design rows (post-transform).
375    pub n: usize,
376    /// Active design columns / penalised-Hessian dimension.
377    pub p: usize,
378    /// Number of smoothing parameters ρ the outer BFGS optimises over.
379    pub num_rho: usize,
380    /// Inner family / link pair the device-resident PIRLS loop can evaluate.
381    /// `None` means the family does not map onto the six JIT-cached row
382    /// kernels — the outer loop must stay on the host orchestrator because
383    /// the inner step would already hop out anyway.
384    pub family: Option<PirlsLoopFamilyKind>,
385    /// Curvature surface the inner loop will use; tied to `family` via
386    /// `pirls_loop_curvature_for`.
387    pub curvature: PirlsLoopCurvatureKind,
388    /// True when the CUDA runtime is initialised on this host.
389    pub gpu_available: bool,
390}
391
392/// Inputs to [`should_use_gpu_pirls_loop`]. Each field comes from data the
393/// CPU PIRLS entry has on hand before it touches the eigendecomposition
394/// engine, so the admission check itself is allocation-free and can short-
395/// circuit before any heavy work happens.
396#[derive(Clone, Copy, Debug)]
397pub struct PirlsLoopAdmission {
398    /// Number of rows in the active (post-transform) design matrix.
399    pub n: usize,
400    /// Number of columns in the active design (i.e. `p` of `Xᵀ X`).
401    pub p: usize,
402    /// `Some(_)` when the inner family maps onto one of the six JIT-cached
403    /// `PirlsRowFamily` variants; `None` for custom families that still
404    /// require Stage 6 Level B and have not yet been admitted here.
405    pub family: Option<PirlsLoopFamilyKind>,
406    /// Curvature surface the inner loop will use; the GPU loop has Fisher +
407    /// Observed kernels, anything else (e.g. expected-projection surrogates)
408    /// is not admitted.
409    pub curvature: PirlsLoopCurvatureKind,
410    /// True when the CUDA runtime is initialised on this host (i.e.
411    /// `GpuRuntime::global().is_some()`).
412    pub gpu_available: bool,
413}
414
415impl GpuDispatchPolicy {
416    /// Minimum design column count for the device-resident inner/outer loops.
417    ///
418    /// Below this width the per-iteration `XᵀWX + Cholesky` is dominated by
419    /// launch latency and PCIe staging rather than arithmetic, so the host LM
420    /// loop (which populates the full `PirlsResult` surface as a free
421    /// side-effect) is strictly cheaper. Shared by both the inner PIRLS and
422    /// outer REML admission predicates so they cannot drift apart.
423    pub const DEVICE_LOOP_MIN_P: usize = 32;
424
425    /// Conservative admission predicate for routing
426    /// `fit_model_for_fixed_rho_with_adaptive_kkt` through the Stage 3.3
427    /// device-resident PIRLS loop instead of the CPU LM loop.
428    ///
429    /// The threshold is the dense `XᵀWX` work estimate, not row count alone:
430    /// LLM/SAE fits can have only a few thousand rows but thousands of columns,
431    /// so `2*n*p^2` already dwarfs launch/staging overhead. Smaller fits stay on
432    /// the CPU LM loop where the full `PirlsResult` surface (firth, EDF,
433    /// per-row weights, …) is already populated as a free side-effect of the
434    /// iteration.
435    pub const fn should_use_gpu_pirls_loop(&self, adm: PirlsLoopAdmission) -> bool {
436        if !adm.gpu_available {
437            return false;
438        }
439        if !self.dense_hessian_work_target_is_gpu(adm.n, adm.p) {
440            return false;
441        }
442        match adm.family {
443            Some(_) => true,
444            None => false,
445        }
446    }
447
448    /// Admission predicate for routing the outer REML BFGS-over-ρ loop onto
449    /// a device-resident driver that keeps the BFGS state (ρ, gradient,
450    /// Hessian approx) on-device and only downloads the per-step scalar
451    /// metrics (objective value, gradient norm, convergence flag).
452    ///
453    /// The dense-work threshold piggybacks on the existing inner-PIRLS admission
454    /// predicate because the device-resident outer loop calls
455    /// `pirls_loop_on_stream` per step and must not pay the host hop for small
456    /// fits the inner loop would have rejected anyway. The
457    /// `num_rho ≥ 2` floor rules out the trivial single-smoother case where
458    /// host orchestration is already negligible and the device BFGS state
459    /// (one length-`num_rho` gradient + a `num_rho × num_rho` Hessian
460    /// approx) collapses to a couple of scalars not worth keeping on device.
461    pub const fn should_run_reml_outer_on_device(&self, adm: RemlOuterAdmission) -> bool {
462        if !adm.gpu_available {
463            return false;
464        }
465        if !self.dense_hessian_work_target_is_gpu(adm.n, adm.p) {
466            return false;
467        }
468        if adm.num_rho < 2 {
469            return false;
470        }
471        match adm.family {
472            Some(_) => true,
473            None => false,
474        }
475    }
476}
477
#[cfg(test)]
mod refinement_policy_tests {
    use super::*;

    /// The default policy with only the mixed-precision mode overridden.
    fn policy_with(mode: GpuMixedPrecisionPolicy) -> GpuDispatchPolicy {
        GpuDispatchPolicy {
            mixed_precision: mode,
            ..GpuDispatchPolicy::default()
        }
    }

    #[test]
    fn refinement_policy_admits_large_p() {
        // `Default` selects `Refinement`, so sizes at or above the floor pass.
        let policy = GpuDispatchPolicy::default();
        for p in [512, GpuDispatchPolicy::REFINEMENT_MIN_P] {
            assert!(policy.iterative_refinement_should_attempt(p));
        }
    }

    #[test]
    fn refinement_policy_rejects_small_p() {
        let policy = GpuDispatchPolicy::default();
        for p in [GpuDispatchPolicy::REFINEMENT_MIN_P - 1, 0] {
            assert!(!policy.iterative_refinement_should_attempt(p));
        }
    }

    #[test]
    fn off_policy_never_attempts_refinement() {
        let policy = policy_with(GpuMixedPrecisionPolicy::Off);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }

    #[test]
    fn never_policy_never_attempts_refinement() {
        let policy = policy_with(GpuMixedPrecisionPolicy::Never);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }
}

#[cfg(test)]
mod reduced_schur_matvec_offload_tests {
    use super::*;

    /// The LLM/SAE shape the whole #1017 Phase-1 re-keying targets:
    ///   * a few thousand row blocks;
    ///   * a *wide* border (decoder atom count in the thousands);
    ///   * a modest per-row frame depth;
    ///   * a realistic CG budget.
    ///
    /// The row-count gates (50k) miss this "thousands of tiny dense ops"
    /// shape, and no single per-row op trips the dense floor. The
    /// work-amortised matvec gate must fire on it.
    #[test]
    fn admits_llm_sae_matvec_shape() {
        let pol = GpuDispatchPolicy::default();
        // n≈2000 rows, k≈2048 atoms, M≈8 frame depth: n is far below the 50k
        // row gates, yet the summed CG matvec work is large.
        let (n, k, d) = (2_000usize, 2_048usize, 8usize);
        assert!(pol.reduced_schur_matvec_should_offload(
            n,
            k,
            d,
            GpuDispatchPolicy::MATVEC_OFFLOAD_MIN_CG_ITERS,
        ));
        // The row-count gates reject this n outright...
        assert!(n < pol.xtwx_n_min);
        assert!(n < pol.row_kernel_min_n);
        // ...and each row's d×d block (p = d = 8) is below the device-loop
        // floor, so no single per-row op is large enough for the dense gate.
        assert!(!pol.dense_hessian_work_target_is_gpu(n, d));
        // For contrast: the dense gate keyed on the *border* width does fire
        // (2·n·k² ≫ floor). So it is the per-row op size, not total work,
        // that the dense path misses.
        assert!(pol.dense_hessian_work_target_is_gpu(n, k));
    }

    /// Even with a single CG iteration the wide LLM border clears the
    /// breakeven, so the gate does not rely on an inflated iteration count.
    /// The per-apply work alone is `2_000·(2·8·2_048 + 8²) ≈ 6.6e7` flops
    /// under the conservative `n·(2·d·k + d²)` model, well above 1e7. The true
    /// `n·(4·d·k + d²)` arithmetic is ≈1.3e8.
    #[test]
    fn admits_llm_shape_with_one_cg_iter() {
        let pol = GpuDispatchPolicy::default();
        assert!(pol.reduced_schur_matvec_should_offload(2_000, 2_048, 8, 1));
    }

    /// Tiny shapes, where the host↔device transfer dominates, must stay on
    /// the CPU. The summed matvec work is orders of magnitude below the
    /// staging breakeven.
    #[test]
    fn rejects_tiny_shape_where_transfer_dominates() {
        let pol = GpuDispatchPolicy::default();
        assert!(!pol.reduced_schur_matvec_should_offload(
            30,
            8,
            2,
            GpuDispatchPolicy::MATVEC_OFFLOAD_MIN_CG_ITERS,
        ));
        // The 300×8 shape the production seam tests use as the "stay CPU"
        // canary is rejected here too.
        assert!(!pol.reduced_schur_matvec_should_offload(300, 8, 4, 16));
        // Both shapes above are already stopped by the border floor. Put the
        // border exactly at the floor so the work threshold itself is
        // exercised: ≈ 3e4 work units against a 1e7 breakeven with the
        // default constants.
        assert!(!pol.reduced_schur_matvec_should_offload(
            30,
            GpuDispatchPolicy::DEVICE_LOOP_MIN_P,
            2,
            GpuDispatchPolicy::MATVEC_OFFLOAD_MIN_CG_ITERS,
        ));
    }

    /// A narrow border (k below the device-loop floor) is rejected however
    /// much row/iteration work is piled on, because per-apply launch latency
    /// dominates a sub-`DEVICE_LOOP_MIN_P` border.
    #[test]
    fn rejects_narrow_border_even_with_huge_row_count() {
        let pol = GpuDispatchPolicy::default();
        let floor = GpuDispatchPolicy::DEVICE_LOOP_MIN_P;
        assert!(!pol.reduced_schur_matvec_should_offload(1_000_000, floor - 1, 64, 64));
        // The same shape with the border at the floor is admitted, so the
        // floor, not the work estimate, is what rejects the narrow case.
        assert!(pol.reduced_schur_matvec_should_offload(1_000_000, floor, 64, 64));
    }

    /// Degenerate dimensions are never offloaded (no work, or no solve).
    #[test]
    fn rejects_degenerate_dimensions() {
        let pol = GpuDispatchPolicy::default();
        assert!(!pol.reduced_schur_matvec_should_offload(0, 2_048, 8, 8));
        assert!(!pol.reduced_schur_matvec_should_offload(2_000, 0, 8, 8));
        assert!(!pol.reduced_schur_matvec_should_offload(2_000, 2_048, 0, 8));
        assert!(!pol.reduced_schur_matvec_should_offload(2_000, 2_048, 8, 0));
    }

    /// The gate is monotone in the CG budget. Once a shape is admitted at a
    /// given iteration count, it stays admitted for any larger count, because
    /// more applies over the same resident frames only improve amortization.
    /// A borderline shape crosses the breakeven exactly where the summed work
    /// reaches `MATVEC_OFFLOAD_FLOPS_MIN`.
    #[test]
    fn monotone_in_cg_iters() {
        let pol = GpuDispatchPolicy::default();
        // A border at the floor, shallow frames and few rows. With the default
        // constants, per_apply = 200·(2·4·32 + 16) = 54_400 and the breakeven
        // is ⌈1e7 / 54_400⌉ = 184 iterations.
        let (n, k, d) = (200usize, GpuDispatchPolicy::DEVICE_LOOP_MIN_P, 4usize);
        let per_apply = GpuDispatchPolicy::admission_work_lower_bound(n, k, d);
        let threshold = GpuDispatchPolicy::MATVEC_OFFLOAD_FLOPS_MIN;
        let breakeven_iters = ((threshold + per_apply - 1) / per_apply) as usize;
        assert!(breakeven_iters > 1);

        assert!(!pol.reduced_schur_matvec_should_offload(n, k, d, 1));
        assert!(!pol.reduced_schur_matvec_should_offload(n, k, d, breakeven_iters - 1));
        assert!(pol.reduced_schur_matvec_should_offload(n, k, d, breakeven_iters));
        // Monotonicity: admitted at the breakeven ⇒ admitted at every larger
        // budget.
        assert!(pol.reduced_schur_matvec_should_offload(n, k, d, 1_000));
        assert!(pol.reduced_schur_matvec_should_offload(n, k, d, 5_000));
    }

    /// The admission lower bound must stay strictly below the true per-apply
    /// work `n·(4·d·k + d²)` for any non-degenerate cross-block shape, since it
    /// drops the transpose GEMV. Treating the lower bound as a flop count would
    /// over-report device speedups, so this asserts the gap is real.
    #[test]
    fn admission_lower_bound_undercounts_actual_work() {
        // Pin the formula on the LLM shape: 2_000·(2·8·2_048 + 8²).
        assert_eq!(
            GpuDispatchPolicy::admission_work_lower_bound(2_000, 2_048, 8),
            65_664_000
        );
        for &(n, k, d) in &[
            (2_000usize, 2_048usize, 8usize),
            (200, GpuDispatchPolicy::DEVICE_LOOP_MIN_P, 4),
            (1, 1, 1),
        ] {
            let lower = GpuDispatchPolicy::admission_work_lower_bound(n, k, d);
            // True per-apply work models the full forward+transpose GEMV pair
            // plus the d×d solve: n·(4·d·k + d²).
            let actual = (n as u128) * (4 * (d as u128) * (k as u128) + (d as u128) * (d as u128));
            assert!(
                lower < actual,
                "admission lower bound {lower} must undercount actual work {actual} for ({n},{k},{d})"
            );
        }
    }
}