gam_gpu/policy.rs
1use serde::{Deserialize, Serialize};
2
3#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
4pub enum GpuMixedPrecisionPolicy {
5 /// Always use fp64 factorization; no refinement attempted.
6 Off,
7 /// Attempt fp32 Cholesky factorization followed by up to
8 /// `REFINEMENT_MAX_STEPS` fp64-residual refinement steps. Policy admits
9 /// the attempt only when `p ≥ REFINEMENT_MIN_P` (so that the fp64 GEMV
10 /// overhead is amortized) and the measured residual drops monotonically.
11 /// Falls back to fp64 factorization automatically when the residual does
12 /// not decrease (κ(A)·u ≥ 1 regime) or when the fp32 POTRF itself fails.
13 Refinement,
14 /// Always use fp64 factorization; equivalent to `Off` but signals that
15 /// an explicit policy decision was taken.
16 Never,
17}
18
19#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
20pub struct GpuDispatchPolicy {
21 pub xtwx_n_min: usize,
22 pub xtwx_flops_min: usize,
23 pub xtwx_use_fused_below_p: usize,
24 pub gemm_min_flops: usize,
25 pub potrf_min_p: usize,
26 pub small_dense_batched_potrf_max_p: usize,
27 pub small_dense_batched_potrf_min_batch: usize,
28 pub syevd_min_p: usize,
29 pub sparse_min_nnz: usize,
30 pub fused_kernel_min_n: usize,
31 pub keep_design_resident_min_bytes: usize,
32 pub prefer_gpu_factorization_min_p: usize,
33 pub row_kernel_min_n: usize,
34 pub mixed_precision: GpuMixedPrecisionPolicy,
35}
36
37impl Default for GpuDispatchPolicy {
38 /// Conservative seed thresholds used before device calibration and when
39 /// calibration cannot run on the current host.
40 ///
41 /// The production runtime replaces these with
42 /// [`crate::calibration::calibrated_policy_for_device`] after the CUDA
43 /// probe selects a concrete device. Keep these values conservative: they
44 /// are the typed baseline for CPU-only builds, failed calibration, and unit
45 /// tests that exercise policy predicates without initializing CUDA.
46 fn default() -> Self {
47 Self {
48 xtwx_n_min: 50_000,
49 xtwx_flops_min: 100_000_000,
50 xtwx_use_fused_below_p: 256,
51 gemm_min_flops: 100_000_000,
52 potrf_min_p: 512,
53 small_dense_batched_potrf_max_p: 32,
54 small_dense_batched_potrf_min_batch: 8,
55 syevd_min_p: 256,
56 sparse_min_nnz: 1_000_000,
57 fused_kernel_min_n: 100_000,
58 keep_design_resident_min_bytes: 32 * 1024 * 1024,
59 prefer_gpu_factorization_min_p: 512,
60 row_kernel_min_n: 50_000,
61 mixed_precision: GpuMixedPrecisionPolicy::Refinement,
62 }
63 }
64}
65
66impl GpuDispatchPolicy {
67 /// Minimum problem dimension for the fp32+refinement path.
68 ///
69 /// Below this threshold the fp64 GEMV needed for the residual check costs
70 /// more than the savings from fp32 factorization. The threshold is set so
71 /// that a single `p × p` DGEMV (2p² flops) is at least 10× cheaper than
72 /// the `p³/3` POTRF (i.e. p ≥ 64) while still leaving margin for the
73 /// POTRF/POTRS launches. In practice `p ≥ 64` matches the existing
74 /// `potrf_min_p = 512` floor for GPU dispatch, so the refinement path only
75 /// activates when the GPU factorization path is already chosen.
76 pub const REFINEMENT_MIN_P: usize = 64;
77
78 /// Maximum number of fp32-correction steps per solve.
79 ///
80 /// Two steps suffice for κ(A) ≤ 10⁵ at fp32 (u ≈ 6 × 10⁻⁸): after step
81 /// 1 the error is O(κ u)² ≈ 10⁻⁶, after step 2 it is O(κ u)⁴ ≈ 10⁻¹²,
82 /// which is well within the fp64 unit roundoff of 10⁻¹⁶ × κ. A cap of 3
83 /// is used defensively.
84 pub const REFINEMENT_MAX_STEPS: usize = 3;
85
86 /// Relative residual tolerance for declaring convergence.
87 ///
88 /// `‖r‖ / ‖b‖ ≤ tol` is considered a converged solve. 10⁻¹² is two
89 /// orders of magnitude above the fp64 machine epsilon times a moderate
90 /// condition number, leaving the policy conservative.
91 pub const REFINEMENT_TOL: f64 = 1e-12;
92
93 /// Return `true` when the policy and problem size together suggest that
94 /// attempting fp32 factorization + iterative refinement will be profitable.
95 ///
96 /// The predicate is conservative:
97 /// * `GpuMixedPrecisionPolicy::Off` or `Never` → always `false`.
98 /// * `Refinement` with `p < REFINEMENT_MIN_P` → `false` (GEMV overhead
99 /// not amortised by fp32 POTRF savings below this threshold).
100 /// * Otherwise `true`; the caller still falls back to fp64 factorization
101 /// when the runtime fp32 POTRF fails or when the measured residual is
102 /// non-monotone.
103 #[inline]
104 pub const fn iterative_refinement_should_attempt(&self, p: usize) -> bool {
105 match self.mixed_precision {
106 GpuMixedPrecisionPolicy::Off | GpuMixedPrecisionPolicy::Never => false,
107 GpuMixedPrecisionPolicy::Refinement => p >= Self::REFINEMENT_MIN_P,
108 }
109 }
110
111 pub const fn dense_gemv_target_is_gpu(&self, n: usize, p: usize, resident: bool) -> bool {
112 resident || n.saturating_mul(p).saturating_mul(2) >= self.gemm_min_flops
113 }
114
115 pub const fn xtwx_target_is_gpu(&self, n: usize, p: usize, materialized: bool) -> bool {
116 materialized && n > 0 && p > 0 && self.xtwx_flops(n, p) >= self.dense_reduction_flops_min()
117 }
118
119 pub const fn xtwy_target_is_gpu(
120 &self,
121 n: usize,
122 px: usize,
123 q: usize,
124 materialized: bool,
125 ) -> bool {
126 materialized
127 && n > 0
128 && px > 0
129 && q > 0
130 && self.xtwy_flops(n, px, q) >= self.dense_reduction_flops_min()
131 }
132
133 pub const fn potrf_target_is_gpu(&self, p: usize, h_resident: bool) -> bool {
134 h_resident && p >= self.potrf_min_p
135 }
136
137 pub const fn dense_hessian_work_target_is_gpu(&self, n: usize, p: usize) -> bool {
138 n > 0
139 && p >= Self::DEVICE_LOOP_MIN_P
140 && self.xtwx_flops(n, p) >= self.dense_reduction_flops_min()
141 }
142
143 const fn dense_reduction_flops_min(&self) -> u128 {
144 if self.xtwx_flops_min < self.gemm_min_flops {
145 self.xtwx_flops_min as u128
146 } else {
147 self.gemm_min_flops as u128
148 }
149 }
150
151 const fn xtwx_flops(&self, n: usize, p: usize) -> u128 {
152 2u128 * (n as u128) * (p as u128) * (p as u128)
153 }
154
155 const fn xtwy_flops(&self, n: usize, px: usize, q: usize) -> u128 {
156 2u128 * (n as u128) * (px as u128) * (q as u128)
157 }
158
159 /// Minimum total CG-amortised matvec flops below which the host↔device
160 /// transfer of the row frames + CG vectors is not repaid by the device
161 /// matvec, so the reduced-Schur PCG hot loop stays on the CPU.
162 ///
163 /// The dense-Direct path keys on `dense_reduction_flops_min` (a single big
164 /// factorization). The matrix-free SAE matvec is different: no single apply
165 /// trips that floor (each is a stack of `n` tiny `d×d` solves + sparse
166 /// `m·k` gather/scatter), but the *whole CG solve* runs the apply
167 /// `O(cg_iters)` times over the same resident frames. The device wins when
168 /// the **summed** matvec work over the solve exceeds the one-time staging
169 /// cost — so the gate keys on `cg_iters · per_apply_flops`, not one apply.
170 ///
171 /// Set one order of magnitude below the dense floor: the matvec frames stay
172 /// resident across CG iterations (uploaded once), so the per-flop transfer
173 /// amortization is `1/cg_iters` of a cold dense launch, and the breakeven
174 /// drops accordingly.
175 pub const MATVEC_OFFLOAD_FLOPS_MIN: u128 = 10_000_000;
176
177 /// Conservative seed for the reduced-Schur PCG iteration count when the
178 /// caller cannot supply a measured budget. InexactPCG on an SAE β-block of
179 /// width `k` converges in `O(√κ)` iterations; this floor keeps the work
180 /// estimate honest (≥ this many applies) without over-claiming a tight
181 /// solve. Used only to amortise the staging cost in the work estimate.
182 pub const MATVEC_OFFLOAD_MIN_CG_ITERS: usize = 8;
183
184 /// Per-apply flop estimate for one reduced-Schur matvec `S·x` of a
185 /// matrix-free SAE Kronecker system, as a pure function of the system shape.
186 ///
187 /// Per row block `i` the apply does: a forward cross-block GEMV
188 /// `v_i = H_tβ^(i)·x` (`≈ 2·d·k` multiply-adds, with the per-row latent
189 /// depth `d` as the M-frame width and `k` the border), a `d×d` triangular
190 /// solve through the cached Cholesky factor (`≈ d²`), and a transpose
191 /// cross-block GEMV `H_βt^(i)·w_i` (`≈ 2·d·k`). The two `2·d·k` GEMVs would
192 /// sum to `4·d·k`; this estimate deliberately undercounts to a single
193 /// `2·d·k` cross term as a conservative (lower-bound) admission floor, so
194 /// the apply is modelled as `≈ n·(2·d·k + d²)`. This is a deliberate
195 /// lower bound on the true `≈ n·(4·d·k + d²)` arithmetic — admitting a
196 /// shape under the smaller figure can only be more conservative, never
197 /// over-eager. It is keyed on the *frame depth* `d` (M) and border width
198 /// `k` (p), not row count alone, so LLM shapes (few rows, wide `k`, modest
199 /// `d`) register arithmetic the row-count gate misses.
200 ///
201 /// USE FOR DISPATCH GATING ONLY. This is **not** a flop count: it omits the
202 /// transpose cross-block GEMV (`2·d·k`), so it is a strict lower bound on the
203 /// true per-apply work `n·(4·d·k + d²)`. The gate can therefore only
204 /// under-admit, never over-admit. Do not reuse it for benchmark / speedup
205 /// accounting.
206 const fn admission_work_lower_bound(n: usize, k: usize, d: usize) -> u128 {
207 let n = n as u128;
208 let k = k as u128;
209 let d = d as u128;
210 // 2·d·k cross-block apply (forward only) + d² per-row solve — the
211 // transpose GEMV is intentionally dropped so this stays a lower bound.
212 n.saturating_mul(
213 2u128
214 .saturating_mul(d)
215 .saturating_mul(k)
216 .saturating_add(d * d),
217 )
218 }
219
220 /// Work-based admission for offloading the **reduced-Schur PCG matvec**
221 /// (the InexactPCG hot loop for matrix-free SAE β-blocks) to the device.
222 ///
223 /// This is the Phase-1 (#1017) re-keying: the dense gates key on row count
224 /// (`xtwx_n_min`, `row_kernel_min_n` at 50k) or a single big-factorization
225 /// flop floor, neither of which the SAE LLM shape trips — `(n≈2000) ×
226 /// (k≈2048) × (d≈8)` is *thousands of small dense ops*, no single op large,
227 /// so the row-count gate keeps the whole fit on one CPU core. Here the gate
228 /// is the **total batched work over the CG solve**:
229 ///
230 /// ```text
231 /// estimated_device_flops = cg_iters · per_apply_flops(n, k, d)
232 /// should_offload = estimated_device_flops ≥ T_breakeven
233 /// ```
234 ///
235 /// where `T_breakeven = MATVEC_OFFLOAD_FLOPS_MIN` accounts for the
236 /// host↔device staging of the row frames + CG vectors amortised over the
237 /// `cg_iters` applies that reuse the resident frames (so the per-flop
238 /// transfer cost is `1/cg_iters` of a cold launch, an order of magnitude
239 /// below the dense-Direct floor).
240 ///
241 /// Pure function of the shape: no device needed to evaluate, so it is unit-
242 /// testable. The caller still falls back to the bit-identical CPU matvec
243 /// whenever the backend build declines, so admitting a shape never changes
244 /// the numerics — only where the `Σ_i Y_iᵀ(Y_i x)` flops execute.
245 ///
246 /// * `n` — number of row blocks (SAE observations / latent rows).
247 /// * `k` — border β width (the SAE decoder atom count `K`).
248 /// * `d` — per-row latent / active-frame depth (the M dimension).
249 /// * `cg_iters` — expected PCG iteration budget; the per-apply work is
250 /// multiplied by this because the frames stay resident across iterations.
251 /// Pass [`Self::MATVEC_OFFLOAD_MIN_CG_ITERS`] when no measured budget is
252 /// available; a tighter (smaller) value only makes the gate stricter.
253 ///
254 /// ## Live arrow-Schur call site
255 ///
256 /// `crate::solver::arrow_schur::maybe_inject_gpu_schur_matvec` gates the
257 /// InexactPCG reduced-Schur matvec injection on this predicate:
258 /// `reduced_schur_matvec_should_offload(sys.rows.len(), sys.k, sys.d,
259 /// options.pcg.max_iterations.min(options.trust_region.max_iterations))`,
260 /// where `sys.d` is the system's max per-row latent depth and the iteration
261 /// budget is the same `max_iterations` the PCG loop launches with.
262 /// `try_device_arrow_direct` (the **dense** Direct point solve) correctly
263 /// keeps `dense_hessian_work_target_is_gpu`: that path is a single large
264 /// factorization, not the amortised matvec.
265 pub const fn reduced_schur_matvec_should_offload(
266 &self,
267 n: usize,
268 k: usize,
269 d: usize,
270 cg_iters: usize,
271 ) -> bool {
272 if n == 0 || k == 0 || d == 0 || cg_iters == 0 {
273 return false;
274 }
275 // The border width must clear the device-loop floor: below it the per-
276 // apply launch latency (one kernel sequence per matvec) dominates any
277 // arithmetic regardless of how many CG iterations run.
278 if k < Self::DEVICE_LOOP_MIN_P {
279 return false;
280 }
281 let per_apply = Self::admission_work_lower_bound(n, k, d);
282 let total = per_apply.saturating_mul(cg_iters as u128);
283 total >= Self::MATVEC_OFFLOAD_FLOPS_MIN
284 }
285}
286
/// Aspirational single-GPU design-row throughput for the #1412 decision gate:
/// 100 000 design rows per wall-clock second per device on the LLM-shape
/// batched-Cholesky + tile-GEMM fit pipeline.
///
/// The original gate *claimed* this figure without measuring it. The intended
/// contract runs the other way: a benchmark (`examples/throughput_1412.rs`)
/// measures the real rows/sec on a device, and
/// [`GpuThroughputVerdict::from_measurement`] reports whether that measurement
/// reaches the target. The verdict is a *function of the measurement*, not a
/// hardcoded assertion. See `tests/owed_1412.rs`.
pub const GPU_THROUGHPUT_TARGET_ROWS_PER_SEC: f64 = 1.0e5;
298
/// Result of comparing a *measured* GPU throughput with a target.
///
/// Verdicts are meant to be produced by [`Self::from_measurement`] (or
/// `from_measurement_against`), which derive every field from the measurement
/// so a verdict does not assert a target that was never measured. Note that
/// all fields are public, so the type itself does not prevent building a
/// verdict by hand.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GpuThroughputVerdict {
    /// Design rows per second measured on the device under test.
    pub measured_rows_per_sec: f64,
    /// Target the measurement was compared against.
    pub target_rows_per_sec: f64,
    /// `measured / target`; a value ≥ 1.0 means the target was reached. The
    /// constructors report `0.0` when the ratio is not meaningful.
    pub fraction_of_target: f64,
    /// Whether the measurement reached the target. The constructors set this
    /// only for a finite, positive measurement with `measured >= target`.
    pub meets_target: bool,
}
313
314impl GpuThroughputVerdict {
315 /// Build a verdict from a measured throughput against
316 /// [`GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`]. A non-finite or non-positive
317 /// measurement can never meet the target (it is not a usable measurement).
318 #[inline]
319 pub fn from_measurement(measured_rows_per_sec: f64) -> Self {
320 Self::from_measurement_against(measured_rows_per_sec, GPU_THROUGHPUT_TARGET_ROWS_PER_SEC)
321 }
322
323 /// Build a verdict against an explicit target (used by tests that probe the
324 /// comparison logic without depending on the global target constant).
325 #[inline]
326 pub fn from_measurement_against(measured_rows_per_sec: f64, target_rows_per_sec: f64) -> Self {
327 let usable = measured_rows_per_sec.is_finite() && measured_rows_per_sec > 0.0;
328 let fraction_of_target = if usable && target_rows_per_sec > 0.0 {
329 measured_rows_per_sec / target_rows_per_sec
330 } else {
331 0.0
332 };
333 Self {
334 measured_rows_per_sec,
335 target_rows_per_sec,
336 fraction_of_target,
337 meets_target: usable && measured_rows_per_sec >= target_rows_per_sec,
338 }
339 }
340}
341
/// `(response, link)` families that the Stage 3.3 device-resident PIRLS loop
/// can evaluate without going through the Level-B raw-body NVRTC path.
///
/// Documented as a policy-layer mirror of `PirlsRowFamily::ALL`, kept here so
/// the admission predicate stays linkable from the CPU PIRLS entry without
/// dragging a Linux-only enum into every host compilation unit. Each variant
/// name spells the response distribution followed by the link.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PirlsLoopFamilyKind {
    BernoulliLogit,
    BernoulliProbit,
    BernoulliCLogLog,
    PoissonLog,
    GaussianIdentity,
    GammaLog,
}
357
/// Curvature surface requested for the inner PIRLS loop.
///
/// Carried by the admission structs; the admission predicates in this file do
/// not branch on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PirlsLoopCurvatureKind {
    Fisher,
    Observed,
}
363
364/// Inputs to [`should_run_reml_outer_on_device`]. The admission predicate
365/// for routing the *outer* REML BFGS-over-ρ loop onto a fully device-resident
366/// driver (rather than the host orchestrator that hops out per step).
367///
368/// Fields are intentionally lifted from data the CPU REML entry has on hand
369/// before it touches the seed generator or the inner P-IRLS loop, so the
370/// admission check is allocation-free and can short-circuit before any
371/// device call.
372#[derive(Clone, Copy, Debug)]
373pub struct RemlOuterAdmission {
374 /// Active design rows (post-transform).
375 pub n: usize,
376 /// Active design columns / penalised-Hessian dimension.
377 pub p: usize,
378 /// Number of smoothing parameters ρ the outer BFGS optimises over.
379 pub num_rho: usize,
380 /// Inner family / link pair the device-resident PIRLS loop can evaluate.
381 /// `None` means the family does not map onto the six JIT-cached row
382 /// kernels — the outer loop must stay on the host orchestrator because
383 /// the inner step would already hop out anyway.
384 pub family: Option<PirlsLoopFamilyKind>,
385 /// Curvature surface the inner loop will use; tied to `family` via
386 /// `pirls_loop_curvature_for`.
387 pub curvature: PirlsLoopCurvatureKind,
388 /// True when the CUDA runtime is initialised on this host.
389 pub gpu_available: bool,
390}
391
392/// Inputs to [`should_use_gpu_pirls_loop`]. Each field comes from data the
393/// CPU PIRLS entry has on hand before it touches the eigendecomposition
394/// engine, so the admission check itself is allocation-free and can short-
395/// circuit before any heavy work happens.
396#[derive(Clone, Copy, Debug)]
397pub struct PirlsLoopAdmission {
398 /// Number of rows in the active (post-transform) design matrix.
399 pub n: usize,
400 /// Number of columns in the active design (i.e. `p` of `Xᵀ X`).
401 pub p: usize,
402 /// `Some(_)` when the inner family maps onto one of the six JIT-cached
403 /// `PirlsRowFamily` variants; `None` for custom families that still
404 /// require Stage 6 Level B and have not yet been admitted here.
405 pub family: Option<PirlsLoopFamilyKind>,
406 /// Curvature surface the inner loop will use; the GPU loop has Fisher +
407 /// Observed kernels, anything else (e.g. expected-projection surrogates)
408 /// is not admitted.
409 pub curvature: PirlsLoopCurvatureKind,
410 /// True when the CUDA runtime is initialised on this host (i.e.
411 /// `GpuRuntime::global().is_some()`).
412 pub gpu_available: bool,
413}
414
415impl GpuDispatchPolicy {
416 /// Minimum design column count for the device-resident inner/outer loops.
417 ///
418 /// Below this width the per-iteration `XᵀWX + Cholesky` is dominated by
419 /// launch latency and PCIe staging rather than arithmetic, so the host LM
420 /// loop (which populates the full `PirlsResult` surface as a free
421 /// side-effect) is strictly cheaper. Shared by both the inner PIRLS and
422 /// outer REML admission predicates so they cannot drift apart.
423 pub const DEVICE_LOOP_MIN_P: usize = 32;
424
425 /// Conservative admission predicate for routing
426 /// `fit_model_for_fixed_rho_with_adaptive_kkt` through the Stage 3.3
427 /// device-resident PIRLS loop instead of the CPU LM loop.
428 ///
429 /// The threshold is the dense `XᵀWX` work estimate, not row count alone:
430 /// LLM/SAE fits can have only a few thousand rows but thousands of columns,
431 /// so `2*n*p^2` already dwarfs launch/staging overhead. Smaller fits stay on
432 /// the CPU LM loop where the full `PirlsResult` surface (firth, EDF,
433 /// per-row weights, …) is already populated as a free side-effect of the
434 /// iteration.
435 pub const fn should_use_gpu_pirls_loop(&self, adm: PirlsLoopAdmission) -> bool {
436 if !adm.gpu_available {
437 return false;
438 }
439 if !self.dense_hessian_work_target_is_gpu(adm.n, adm.p) {
440 return false;
441 }
442 match adm.family {
443 Some(_) => true,
444 None => false,
445 }
446 }
447
448 /// Admission predicate for routing the outer REML BFGS-over-ρ loop onto
449 /// a device-resident driver that keeps the BFGS state (ρ, gradient,
450 /// Hessian approx) on-device and only downloads the per-step scalar
451 /// metrics (objective value, gradient norm, convergence flag).
452 ///
453 /// The dense-work threshold piggybacks on the existing inner-PIRLS admission
454 /// predicate because the device-resident outer loop calls
455 /// `pirls_loop_on_stream` per step and must not pay the host hop for small
456 /// fits the inner loop would have rejected anyway. The
457 /// `num_rho ≥ 2` floor rules out the trivial single-smoother case where
458 /// host orchestration is already negligible and the device BFGS state
459 /// (one length-`num_rho` gradient + a `num_rho × num_rho` Hessian
460 /// approx) collapses to a couple of scalars not worth keeping on device.
461 pub const fn should_run_reml_outer_on_device(&self, adm: RemlOuterAdmission) -> bool {
462 if !adm.gpu_available {
463 return false;
464 }
465 if !self.dense_hessian_work_target_is_gpu(adm.n, adm.p) {
466 return false;
467 }
468 if adm.num_rho < 2 {
469 return false;
470 }
471 match adm.family {
472 Some(_) => true,
473 None => false,
474 }
475 }
476}
477
#[cfg(test)]
mod refinement_policy_tests {
    use super::*;

    /// Seed policy with only the mixed-precision mode replaced.
    fn seed_with_mode(mode: GpuMixedPrecisionPolicy) -> GpuDispatchPolicy {
        GpuDispatchPolicy {
            mixed_precision: mode,
            ..Default::default()
        }
    }

    #[test]
    fn refinement_policy_admits_large_p() {
        // The seed policy runs in `Refinement` mode, so any p at or above the
        // floor is admitted — including the floor itself.
        let policy = GpuDispatchPolicy::default();
        for &p in &[512usize, GpuDispatchPolicy::REFINEMENT_MIN_P] {
            assert!(policy.iterative_refinement_should_attempt(p));
        }
    }

    #[test]
    fn refinement_policy_rejects_small_p() {
        let policy = GpuDispatchPolicy::default();
        for &p in &[GpuDispatchPolicy::REFINEMENT_MIN_P - 1, 0usize] {
            assert!(!policy.iterative_refinement_should_attempt(p));
        }
    }

    #[test]
    fn off_policy_never_attempts_refinement() {
        let policy = seed_with_mode(GpuMixedPrecisionPolicy::Off);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }

    #[test]
    fn never_policy_never_attempts_refinement() {
        let policy = seed_with_mode(GpuMixedPrecisionPolicy::Never);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }
}
515
#[cfg(test)]
mod reduced_schur_matvec_offload_tests {
    use super::*;

    // The LLM/SAE shape targeted by the #1017 Phase-1 re-keying: a few
    // thousand row blocks, a *wide* border (decoder atom count in the
    // thousands) and a modest per-row frame depth.
    const LLM_N: usize = 2_000;
    const LLM_K: usize = 2_048;
    const LLM_D: usize = 8;

    /// Evaluates the matvec offload gate on the seed policy.
    fn seed_offloads(n: usize, k: usize, d: usize, cg_iters: usize) -> bool {
        GpuDispatchPolicy::default().reduced_schur_matvec_should_offload(n, k, d, cg_iters)
    }

    /// With a realistic CG budget the work-amortised matvec gate must fire on
    /// the LLM/SAE shape, even though its row count is far below the 50k row
    /// gate.
    #[test]
    fn admits_llm_sae_matvec_shape() {
        assert!(seed_offloads(
            LLM_N,
            LLM_K,
            LLM_D,
            GpuDispatchPolicy::MATVEC_OFFLOAD_MIN_CG_ITERS,
        ));
        // The dense gate is evaluated here with the frame depth (8) as its
        // column count; that is below `DEVICE_LOOP_MIN_P`, so it rejects.
        let dense_gate =
            GpuDispatchPolicy::default().dense_hessian_work_target_is_gpu(LLM_N, LLM_D);
        assert!(!dense_gate);
    }

    /// A single CG iteration is already enough for the wide LLM border: the
    /// per-apply lower bound alone is `2_000·(2·8·2_048 + 8²) ≈ 6.6e7` > 1e7
    /// (the true `n·(4·d·k + d²)` arithmetic is ≈1.3e8), so admission does not
    /// rely on an inflated iteration count.
    #[test]
    fn admits_llm_shape_with_one_cg_iter() {
        assert!(seed_offloads(LLM_N, LLM_K, LLM_D, 1));
    }

    /// Tiny shapes where the host↔device transfer dominates must stay on the
    /// CPU: a handful of rows, a narrow border, shallow frames.
    #[test]
    fn rejects_tiny_shape_where_transfer_dominates() {
        let tiny_shapes = [
            (30usize, 8usize, 2usize, GpuDispatchPolicy::MATVEC_OFFLOAD_MIN_CG_ITERS),
            // The 300×8 shape the production seam tests use as the "stay CPU"
            // canary.
            (300, 8, 4, 16),
        ];
        for &(n, k, d, cg_iters) in tiny_shapes.iter() {
            assert!(!seed_offloads(n, k, d, cg_iters));
        }
    }

    /// A border narrower than the device-loop floor is rejected no matter how
    /// much row/iteration work is piled on.
    #[test]
    fn rejects_narrow_border_even_with_huge_row_count() {
        let below_floor = GpuDispatchPolicy::DEVICE_LOOP_MIN_P - 1;
        assert!(!seed_offloads(1_000_000, below_floor, 64, 64));
    }

    /// Degenerate dimensions are never offloaded (no work, or no solve).
    #[test]
    fn rejects_degenerate_dimensions() {
        let degenerate = [
            (0usize, LLM_K, LLM_D, 8usize),
            (LLM_N, 0, LLM_D, 8),
            (LLM_N, LLM_K, 0, 8),
            (LLM_N, LLM_K, LLM_D, 0),
        ];
        for &(n, k, d, cg_iters) in degenerate.iter() {
            assert!(!seed_offloads(n, k, d, cg_iters));
        }
    }

    /// The gate is monotone in the CG budget: a borderline shape crosses the
    /// breakeven as iterations grow and stays admitted for larger budgets.
    #[test]
    fn monotone_in_cg_iters() {
        // Border exactly at the floor, shallow frames, few rows:
        // per_apply = 200·(2·4·32 + 16) = 200·272 = 54_400, so the 1e7
        // breakeven is crossed at about 184 iterations.
        let n = 200usize;
        let k = GpuDispatchPolicy::DEVICE_LOOP_MIN_P;
        let d = 4usize;
        assert!(!seed_offloads(n, k, d, 1));
        assert!(seed_offloads(n, k, d, 1_000));
        // Admitted at 1_000 ⇒ admitted at a larger budget too.
        assert!(seed_offloads(n, k, d, 5_000));
    }

    /// The admission lower bound must stay strictly below the true per-apply
    /// work `n·(4·d·k + d²)` for the shapes probed here, because it drops the
    /// transpose GEMV. Treating it as a flop count would over-report device
    /// speedups.
    #[test]
    fn admission_lower_bound_undercounts_actual_work() {
        let shapes = [
            (LLM_N, LLM_K, LLM_D),
            (200usize, GpuDispatchPolicy::DEVICE_LOOP_MIN_P, 4usize),
            (1, 1, 1),
        ];
        for &(n, k, d) in shapes.iter() {
            let lower = GpuDispatchPolicy::admission_work_lower_bound(n, k, d);
            // Forward + transpose GEMV pair plus the d×d solve.
            let (rows, border, depth) = (n as u128, k as u128, d as u128);
            let actual = rows * (4 * depth * border + depth * depth);
            assert!(
                lower < actual,
                "admission lower bound {lower} must undercount actual work {actual} for ({n},{k},{d})"
            );
        }
    }
}
630}