Skip to main content

gam_gpu/
policy.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Clone, Copy, Debug, Eq, PartialEq, Serialize, Deserialize)]
4pub enum GpuMixedPrecisionPolicy {
5    /// Always use fp64 factorization; no refinement attempted.
6    Off,
7    /// Attempt fp32 Cholesky factorization followed by up to
8    /// `REFINEMENT_MAX_STEPS` fp64-residual refinement steps. Policy admits
9    /// the attempt only when `p ≥ REFINEMENT_MIN_P` (so that the fp64 GEMV
10    /// overhead is amortized) and the measured residual drops monotonically.
11    /// Falls back to fp64 factorization automatically when the residual does
12    /// not decrease (κ(A)·u ≥ 1 regime) or when the fp32 POTRF itself fails.
13    Refinement,
14    /// Always use fp64 factorization; equivalent to `Off` but signals that
15    /// an explicit policy decision was taken.
16    Never,
17}
18
19#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
20pub struct GpuDispatchPolicy {
21    pub xtwx_n_min: usize,
22    pub xtwx_flops_min: usize,
23    pub xtwx_use_fused_below_p: usize,
24    pub gemm_min_flops: usize,
25    pub potrf_min_p: usize,
26    pub small_dense_batched_potrf_max_p: usize,
27    pub small_dense_batched_potrf_min_batch: usize,
28    pub syevd_min_p: usize,
29    pub sparse_min_nnz: usize,
30    pub fused_kernel_min_n: usize,
31    pub keep_design_resident_min_bytes: usize,
32    pub prefer_gpu_factorization_min_p: usize,
33    pub row_kernel_min_n: usize,
34    pub mixed_precision: GpuMixedPrecisionPolicy,
35}
36
37impl Default for GpuDispatchPolicy {
38    /// Conservative seed thresholds used before device calibration and when
39    /// calibration cannot run on the current host.
40    ///
41    /// The production runtime replaces these with
42    /// [`crate::calibration::calibrated_policy_for_device`] after the CUDA
43    /// probe selects a concrete device. Keep these values conservative: they
44    /// are the typed baseline for CPU-only builds, failed calibration, and unit
45    /// tests that exercise policy predicates without initializing CUDA.
46    fn default() -> Self {
47        Self {
48            xtwx_n_min: 50_000,
49            xtwx_flops_min: 100_000_000,
50            xtwx_use_fused_below_p: 256,
51            gemm_min_flops: 100_000_000,
52            potrf_min_p: 512,
53            small_dense_batched_potrf_max_p: 32,
54            small_dense_batched_potrf_min_batch: 8,
55            syevd_min_p: 256,
56            sparse_min_nnz: 1_000_000,
57            fused_kernel_min_n: 100_000,
58            keep_design_resident_min_bytes: 32 * 1024 * 1024,
59            prefer_gpu_factorization_min_p: 512,
60            row_kernel_min_n: 50_000,
61            mixed_precision: GpuMixedPrecisionPolicy::Refinement,
62        }
63    }
64}
65
66impl GpuDispatchPolicy {
67    /// Minimum problem dimension for the fp32+refinement path.
68    ///
69    /// Below this threshold the fp64 GEMV needed for the residual check costs
70    /// more than the savings from fp32 factorization. The threshold is set so
71    /// that a single `p × p` DGEMV (2p² flops) is at least 10× cheaper than
72    /// the `p³/3` POTRF (i.e. p ≥ 64) while still leaving margin for the
73    /// POTRF/POTRS launches. In practice `p ≥ 64` matches the existing
74    /// `potrf_min_p = 512` floor for GPU dispatch, so the refinement path only
75    /// activates when the GPU factorization path is already chosen.
76    pub const REFINEMENT_MIN_P: usize = 64;
77
78    /// Maximum number of fp32-correction steps per solve.
79    ///
80    /// Two steps suffice for κ(A) ≤ 10⁵ at fp32 (u ≈ 6 × 10⁻⁸): after step
81    /// 1 the error is O(κ u)² ≈ 10⁻⁶, after step 2 it is O(κ u)⁴ ≈ 10⁻¹²,
82    /// which is well within the fp64 unit roundoff of 10⁻¹⁶ × κ. A cap of 3
83    /// is used defensively.
84    pub const REFINEMENT_MAX_STEPS: usize = 3;
85
86    /// Relative residual tolerance for declaring convergence.
87    ///
88    /// `‖r‖ / ‖b‖ ≤ tol` is considered a converged solve. 10⁻¹² is two
89    /// orders of magnitude above the fp64 machine epsilon times a moderate
90    /// condition number, leaving the policy conservative.
91    pub const REFINEMENT_TOL: f64 = 1e-12;
92
93    /// Return `true` when the policy and problem size together suggest that
94    /// attempting fp32 factorization + iterative refinement will be profitable.
95    ///
96    /// The predicate is conservative:
97    ///   * `GpuMixedPrecisionPolicy::Off` or `Never` → always `false`.
98    ///   * `Refinement` with `p < REFINEMENT_MIN_P` → `false` (GEMV overhead
99    ///     not amortised by fp32 POTRF savings below this threshold).
100    ///   * Otherwise `true`; the caller still falls back to fp64 factorization
101    ///     when the runtime fp32 POTRF fails or when the measured residual is
102    ///     non-monotone.
103    #[inline]
104    pub const fn iterative_refinement_should_attempt(&self, p: usize) -> bool {
105        match self.mixed_precision {
106            GpuMixedPrecisionPolicy::Off | GpuMixedPrecisionPolicy::Never => false,
107            GpuMixedPrecisionPolicy::Refinement => p >= Self::REFINEMENT_MIN_P,
108        }
109    }
110
111    pub const fn dense_gemv_target_is_gpu(&self, n: usize, p: usize, resident: bool) -> bool {
112        resident || n.saturating_mul(p).saturating_mul(2) >= self.gemm_min_flops
113    }
114
115    pub const fn xtwx_target_is_gpu(&self, n: usize, p: usize, materialized: bool) -> bool {
116        materialized && n > 0 && p > 0 && self.xtwx_flops(n, p) >= self.dense_reduction_flops_min()
117    }
118
119    pub const fn xtwy_target_is_gpu(
120        &self,
121        n: usize,
122        px: usize,
123        q: usize,
124        materialized: bool,
125    ) -> bool {
126        materialized
127            && n > 0
128            && px > 0
129            && q > 0
130            && self.xtwy_flops(n, px, q) >= self.dense_reduction_flops_min()
131    }
132
133    pub const fn potrf_target_is_gpu(&self, p: usize, h_resident: bool) -> bool {
134        h_resident && p >= self.potrf_min_p
135    }
136
137    pub const fn dense_hessian_work_target_is_gpu(&self, n: usize, p: usize) -> bool {
138        n > 0
139            && p >= Self::DEVICE_LOOP_MIN_P
140            && self.xtwx_flops(n, p) >= self.dense_reduction_flops_min()
141    }
142
143    const fn dense_reduction_flops_min(&self) -> u128 {
144        if self.xtwx_flops_min < self.gemm_min_flops {
145            self.xtwx_flops_min as u128
146        } else {
147            self.gemm_min_flops as u128
148        }
149    }
150
151    const fn xtwx_flops(&self, n: usize, p: usize) -> u128 {
152        2u128 * (n as u128) * (p as u128) * (p as u128)
153    }
154
155    const fn xtwy_flops(&self, n: usize, px: usize, q: usize) -> u128 {
156        2u128 * (n as u128) * (px as u128) * (q as u128)
157    }
158
159    /// Minimum total CG-amortised matvec flops below which the host↔device
160    /// transfer of the row frames + CG vectors is not repaid by the device
161    /// matvec, so the reduced-Schur PCG hot loop stays on the CPU.
162    ///
163    /// The dense-Direct path keys on `dense_reduction_flops_min` (a single big
164    /// factorization). The matrix-free SAE matvec is different: no single apply
165    /// trips that floor (each is a stack of `n` tiny `d×d` solves + sparse
166    /// `m·k` gather/scatter), but the *whole CG solve* runs the apply
167    /// `O(cg_iters)` times over the same resident frames. The device wins when
168    /// the **summed** matvec work over the solve exceeds the one-time staging
169    /// cost — so the gate keys on `cg_iters · per_apply_flops`, not one apply.
170    ///
171    /// Set one order of magnitude below the dense floor: the matvec frames stay
172    /// resident across CG iterations (uploaded once), so the per-flop transfer
173    /// amortization is `1/cg_iters` of a cold dense launch, and the breakeven
174    /// drops accordingly.
175    pub const MATVEC_OFFLOAD_FLOPS_MIN: u128 = 10_000_000;
176
177    /// Conservative seed for the reduced-Schur PCG iteration count when the
178    /// caller cannot supply a measured budget. InexactPCG on an SAE β-block of
179    /// width `k` converges in `O(√κ)` iterations; this floor keeps the work
180    /// estimate honest (≥ this many applies) without over-claiming a tight
181    /// solve. Used only to amortise the staging cost in the work estimate.
182    pub const MATVEC_OFFLOAD_MIN_CG_ITERS: usize = 8;
183
184    /// Per-apply flop estimate for one reduced-Schur matvec `S·x` of a
185    /// matrix-free SAE Kronecker system, as a pure function of the system shape.
186    ///
187    /// Per row block `i` the apply does: a forward cross-block GEMV
188    /// `v_i = H_tβ^(i)·x` (`≈ 2·d·k` multiply-adds, with the per-row latent
189    /// depth `d` as the M-frame width and `k` the border), a `d×d` triangular
190    /// solve through the cached Cholesky factor (`≈ d²`), and a transpose
191    /// cross-block GEMV `H_βt^(i)·w_i` (`≈ 2·d·k`). The two `2·d·k` GEMVs would
192    /// sum to `4·d·k`; this estimate deliberately undercounts to a single
193    /// `2·d·k` cross term as a conservative (lower-bound) admission floor, so
194    /// the apply is modelled as `≈ n·(2·d·k + d²)`. This is a deliberate
195    /// lower bound on the true `≈ n·(4·d·k + d²)` arithmetic — admitting a
196    /// shape under the smaller figure can only be more conservative, never
197    /// over-eager. It is keyed on the *frame depth* `d` (M) and border width
198    /// `k` (p), not row count alone, so LLM shapes (few rows, wide `k`, modest
199    /// `d`) register arithmetic the row-count gate misses.
200    ///
201    /// USE FOR DISPATCH GATING ONLY. This is **not** a flop count: it omits the
202    /// transpose cross-block GEMV (`2·d·k`), so it is a strict lower bound on the
203    /// true per-apply work `n·(4·d·k + d²)`. The gate can therefore only
204    /// under-admit, never over-admit. Do not reuse it for benchmark / speedup
205    /// accounting.
206    const fn admission_work_lower_bound(n: usize, k: usize, d: usize) -> u128 {
207        let n = n as u128;
208        let k = k as u128;
209        let d = d as u128;
210        // 2·d·k cross-block apply (forward only) + d² per-row solve — the
211        // transpose GEMV is intentionally dropped so this stays a lower bound.
212        n.saturating_mul(
213            2u128
214                .saturating_mul(d)
215                .saturating_mul(k)
216                .saturating_add(d * d),
217        )
218    }
219
220    /// Work-based admission for offloading the **reduced-Schur PCG matvec**
221    /// (the InexactPCG hot loop for matrix-free SAE β-blocks) to the device.
222    ///
223    /// This is the Phase-1 (#1017) re-keying: the dense gates key on row count
224    /// (`xtwx_n_min`, `row_kernel_min_n` at 50k) or a single big-factorization
225    /// flop floor, neither of which the SAE LLM shape trips — `(n≈2000) ×
226    /// (k≈2048) × (d≈8)` is *thousands of small dense ops*, no single op large,
227    /// so the row-count gate keeps the whole fit on one CPU core. Here the gate
228    /// is the **total batched work over the CG solve**:
229    ///
230    /// ```text
231    /// estimated_device_flops = cg_iters · per_apply_flops(n, k, d)
232    /// should_offload = estimated_device_flops ≥ T_breakeven
233    /// ```
234    ///
235    /// where `T_breakeven = MATVEC_OFFLOAD_FLOPS_MIN` accounts for the
236    /// host↔device staging of the row frames + CG vectors amortised over the
237    /// `cg_iters` applies that reuse the resident frames (so the per-flop
238    /// transfer cost is `1/cg_iters` of a cold launch, an order of magnitude
239    /// below the dense-Direct floor).
240    ///
241    /// Pure function of the shape: no device needed to evaluate, so it is unit-
242    /// testable. The caller still falls back to the bit-identical CPU matvec
243    /// whenever the backend build declines, so admitting a shape never changes
244    /// the numerics — only where the `Σ_i Y_iᵀ(Y_i x)` flops execute.
245    ///
246    /// * `n`        — number of row blocks (SAE observations / latent rows).
247    /// * `k`        — border β width (the SAE decoder atom count `K`).
248    /// * `d`        — per-row latent / active-frame depth (the M dimension).
249    /// * `cg_iters` — expected PCG iteration budget; the per-apply work is
250    ///   multiplied by this because the frames stay resident across iterations.
251    ///   Pass [`Self::MATVEC_OFFLOAD_MIN_CG_ITERS`] when no measured budget is
252    ///   available; a tighter (smaller) value only makes the gate stricter.
253    ///
254    /// ## Live arrow-Schur call site
255    ///
256    /// `crate::solver::arrow_schur::maybe_inject_gpu_schur_matvec` gates the
257    /// InexactPCG reduced-Schur matvec injection on this predicate:
258    /// `reduced_schur_matvec_should_offload(sys.rows.len(), sys.k, sys.d,
259    /// options.pcg.max_iterations.min(options.trust_region.max_iterations))`,
260    /// where `sys.d` is the system's max per-row latent depth and the iteration
261    /// budget is the same `max_iterations` the PCG loop launches with.
262    /// `try_device_arrow_direct` (the **dense** Direct point solve) correctly
263    /// keeps `dense_hessian_work_target_is_gpu`: that path is a single large
264    /// factorization, not the amortised matvec.
265    pub const fn reduced_schur_matvec_should_offload(
266        &self,
267        n: usize,
268        k: usize,
269        d: usize,
270        cg_iters: usize,
271    ) -> bool {
272        if n == 0 || k == 0 || d == 0 || cg_iters == 0 {
273            return false;
274        }
275        // The border width must clear the device-loop floor: below it the per-
276        // apply launch latency (one kernel sequence per matvec) dominates any
277        // arithmetic regardless of how many CG iterations run.
278        if k < Self::DEVICE_LOOP_MIN_P {
279            return false;
280        }
281        let per_apply = Self::admission_work_lower_bound(n, k, d);
282        let total = per_apply.saturating_mul(cg_iters as u128);
283        total >= Self::MATVEC_OFFLOAD_FLOPS_MIN
284    }
285}
286
/// The aspirational single-GPU design-row throughput the #1412 decision gate
/// is supposed to establish for the LLM-shape batched-Cholesky + tile-GEMM fit
/// pipeline: 100 000 design rows processed per wall-clock second per device.
///
/// The original gate *claimed* this number without ever measuring it. The
/// honest contract is the other way around: a benchmark
/// (`examples/throughput_1412.rs`) measures the true rows/sec on a real
/// device, and [`GpuThroughputVerdict::from_measurement`] reports whether the
/// measured value meets the target — the verdict is a *function of the
/// measurement*, not a hardcoded assertion. See `tests/owed_1412.rs`.
pub const GPU_THROUGHPUT_TARGET_ROWS_PER_SEC: f64 = 100_000.0;

/// Outcome of comparing a *measured* GPU throughput against a target.
///
/// Build one through [`Self::from_measurement`] or
/// [`Self::from_measurement_against`]; those constructors keep
/// `fraction_of_target` and `meets_target` consistent with each other, so a
/// verdict never asserts a target that was not established by a usable
/// measurement against a usable target.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct GpuThroughputVerdict {
    /// The measured design-rows-per-second on the device under test.
    pub measured_rows_per_sec: f64,
    /// The target the measurement is compared against.
    pub target_rows_per_sec: f64,
    /// `measured / target` when both values are usable (finite and
    /// positive), otherwise `0.0`. ≥ 1.0 means the target was established.
    pub fraction_of_target: f64,
    /// True iff both values are usable and
    /// `measured_rows_per_sec >= target_rows_per_sec`.
    pub meets_target: bool,
}

impl GpuThroughputVerdict {
    /// Build a verdict from a measured throughput against
    /// [`GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`]. A non-finite or non-positive
    /// measurement can never meet the target (it is not a usable
    /// measurement).
    #[inline]
    pub fn from_measurement(measured_rows_per_sec: f64) -> Self {
        Self::from_measurement_against(measured_rows_per_sec, GPU_THROUGHPUT_TARGET_ROWS_PER_SEC)
    }

    /// Build a verdict against an explicit target (used by tests that probe
    /// the comparison logic without depending on the global target constant).
    ///
    /// A value is *usable* when it is finite and strictly positive. If either
    /// the measurement or the target is unusable, the verdict reports
    /// `fraction_of_target == 0.0` and `meets_target == false`: a degenerate
    /// target (zero, negative, NaN, infinite) cannot be "met", which keeps
    /// the two derived fields from contradicting each other.
    #[inline]
    pub fn from_measurement_against(measured_rows_per_sec: f64, target_rows_per_sec: f64) -> Self {
        let measurement_usable = measured_rows_per_sec.is_finite() && measured_rows_per_sec > 0.0;
        let target_usable = target_rows_per_sec.is_finite() && target_rows_per_sec > 0.0;
        let comparable = measurement_usable && target_usable;

        let fraction_of_target = if comparable {
            measured_rows_per_sec / target_rows_per_sec
        } else {
            0.0
        };
        Self {
            measured_rows_per_sec,
            target_rows_per_sec,
            fraction_of_target,
            meets_target: comparable && measured_rows_per_sec >= target_rows_per_sec,
        }
    }
}
341
/// Reason a Stage-3 encode deployment decision could not be derived from a
/// real device measurement (#988, #1412).
///
/// Every variant names a state in which the `100_000` rows/sec/GPU target
/// was neither established NOR refuted on a device: the decision is blocked
/// on hardware rather than green-washed from a CPU proxy.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum EncodeDecisionBlocked {
    /// This host has no CUDA device, so the exact encode could not be
    /// measured on a device at all. A CPU rate cannot stand in for it — that
    /// substitution was the #1412 defect.
    NoDevice,
    /// A device exists, but no device-resident *exact-encode* kernel does, so
    /// the FULL per-row encode cannot be measured on the device. (The
    /// resident normal-equations solve in [`crate::encode_throughput`] is
    /// only ONE component of the encode, not the encode itself; measuring a
    /// component cannot settle the encode surrogate question — #988.)
    NoDeviceEncodeKernel,
    /// A device exists and a measurement was attempted, but the device path
    /// did not engage (false routing). The result is refused rather than
    /// reported as a pass.
    DeviceNotEngaged,
}
362
363/// Tri-state Stage-3 encode deployment / amortized-surrogate decision
364/// (#988, #1412).
365///
366/// The decision the throughput gate exists to make is empirical: does the EXACT
367/// per-row encode clear the `100_000` rows/sec/GPU deployment target on a real
368/// device? Only a real device measurement can answer it:
369///   * [`Self::Met`] — a device measurement CLEARED the target: ship the exact
370///     encode; the certified amortized surrogate is NOT needed.
371///   * [`Self::Unmet`] — a device measurement MISSED the target: the certified
372///     amortized surrogate becomes justified.
373///   * [`Self::Undetermined`] — no device measurement is available. The decision
374///     is BLOCKED on hardware; it is neither "surrogate unneeded" nor "surrogate
375///     justified".
376///
377/// The critical anti-green-wash property (#1412): there is NO constructor that
378/// takes a CPU rate. A CPU measurement, however fast, can never move the decision
379/// out of [`Self::Undetermined`]. Projecting a CPU rate through an assumed
380/// CPU→GPU factor to declare the target met was the exact #1412 defect and is
381/// structurally impossible here — [`Self::Met`] / [`Self::Unmet`] come only from
382/// [`Self::from_device_measurement`] with `engaged == true`.
383#[derive(Clone, Copy, Debug, PartialEq)]
384pub enum EncodeDeploymentDecision {
385    /// A device measurement established the deployment target.
386    Met {
387        /// The measured device rows/sec that cleared the target.
388        measured_rows_per_sec: f64,
389        /// The target it was compared against.
390        target_rows_per_sec: f64,
391    },
392    /// A device measurement fell short of the deployment target.
393    Unmet {
394        /// The measured device rows/sec that missed the target.
395        measured_rows_per_sec: f64,
396        /// The target it was compared against.
397        target_rows_per_sec: f64,
398    },
399    /// No device measurement is available; the decision is blocked on hardware.
400    Undetermined {
401        /// Why no device measurement could be made.
402        reason: EncodeDecisionBlocked,
403    },
404}
405
406impl EncodeDeploymentDecision {
407    /// The ONLY path to a `Met`/`Unmet` decision: a device measurement that
408    /// actually engaged the device and produced a usable rate. `engaged == false`
409    /// (false routing / CPU decline) or a non-finite / non-positive rate yields
410    /// [`Self::Undetermined`] — never a fabricated pass or fail.
411    #[must_use]
412    pub fn from_device_measurement(engaged: bool, measured_rows_per_sec: f64) -> Self {
413        Self::from_device_measurement_against(
414            engaged,
415            measured_rows_per_sec,
416            GPU_THROUGHPUT_TARGET_ROWS_PER_SEC,
417        )
418    }
419
420    /// [`Self::from_device_measurement`] against an explicit target (for tests
421    /// that probe the decision logic without the global target constant).
422    #[must_use]
423    pub fn from_device_measurement_against(
424        engaged: bool,
425        measured_rows_per_sec: f64,
426        target_rows_per_sec: f64,
427    ) -> Self {
428        let usable = measured_rows_per_sec.is_finite() && measured_rows_per_sec > 0.0;
429        if !engaged || !usable {
430            return Self::Undetermined {
431                reason: EncodeDecisionBlocked::DeviceNotEngaged,
432            };
433        }
434        if measured_rows_per_sec >= target_rows_per_sec {
435            Self::Met {
436                measured_rows_per_sec,
437                target_rows_per_sec,
438            }
439        } else {
440            Self::Unmet {
441                measured_rows_per_sec,
442                target_rows_per_sec,
443            }
444        }
445    }
446
447    /// Construct the blocked decision for a host that cannot measure the exact
448    /// encode on a device. This is the honest CPU-only / no-device-kernel outcome
449    /// — the deployment target is left undetermined rather than projected.
450    #[must_use]
451    pub fn blocked(reason: EncodeDecisionBlocked) -> Self {
452        Self::Undetermined { reason }
453    }
454
455    /// True ONLY when a device measurement cleared the target: the exact encode
456    /// ships and no surrogate is built. Never true from a CPU proxy.
457    #[must_use]
458    pub fn surrogate_unneeded(&self) -> bool {
459        matches!(self, Self::Met { .. })
460    }
461
462    /// True ONLY when a device measurement missed the target: the certified
463    /// amortized surrogate becomes justified. Never true without a measurement.
464    #[must_use]
465    pub fn surrogate_justified(&self) -> bool {
466        matches!(self, Self::Unmet { .. })
467    }
468
469    /// True when no device measurement is available and the decision is blocked
470    /// on hardware (neither [`Self::surrogate_unneeded`] nor
471    /// [`Self::surrogate_justified`]).
472    #[must_use]
473    pub fn is_undetermined(&self) -> bool {
474        matches!(self, Self::Undetermined { .. })
475    }
476}
477
/// The `(response, link)` families the Stage 3.3 device-resident PIRLS loop
/// can evaluate without taking the Level-B raw-body NVRTC path.
///
/// This mirrors `PirlsRowFamily::ALL` at the policy layer, so the admission
/// predicate stays linkable from the CPU PIRLS entry without pulling a
/// Linux-only enum into every host compilation unit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PirlsLoopFamilyKind {
    /// Bernoulli response, logit link.
    BernoulliLogit,
    /// Bernoulli response, probit link.
    BernoulliProbit,
    /// Bernoulli response, complementary log-log link.
    BernoulliCLogLog,
    /// Poisson response, log link.
    PoissonLog,
    /// Gaussian response, identity link.
    GaussianIdentity,
    /// Gamma response, log link.
    GammaLog,
}
493
/// Curvature surface used by the inner PIRLS loop, carried in the admission
/// inputs (`PirlsLoopAdmission`, `RemlOuterAdmission`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PirlsLoopCurvatureKind {
    /// Fisher (expected) curvature.
    Fisher,
    /// Observed curvature.
    Observed,
}
499
500/// Inputs to [`should_run_reml_outer_on_device`]. The admission predicate
501/// for routing the *outer* REML BFGS-over-ρ loop onto a fully device-resident
502/// driver (rather than the host orchestrator that hops out per step).
503///
504/// Fields are intentionally lifted from data the CPU REML entry has on hand
505/// before it touches the seed generator or the inner P-IRLS loop, so the
506/// admission check is allocation-free and can short-circuit before any
507/// device call.
508#[derive(Clone, Copy, Debug)]
509pub struct RemlOuterAdmission {
510    /// Active design rows (post-transform).
511    pub n: usize,
512    /// Active design columns / penalised-Hessian dimension.
513    pub p: usize,
514    /// Number of smoothing parameters ρ the outer BFGS optimises over.
515    pub num_rho: usize,
516    /// Inner family / link pair the device-resident PIRLS loop can evaluate.
517    /// `None` means the family does not map onto the six JIT-cached row
518    /// kernels — the outer loop must stay on the host orchestrator because
519    /// the inner step would already hop out anyway.
520    pub family: Option<PirlsLoopFamilyKind>,
521    /// Curvature surface the inner loop will use; tied to `family` via
522    /// `pirls_loop_curvature_for`.
523    pub curvature: PirlsLoopCurvatureKind,
524    /// True when the CUDA runtime is initialised on this host.
525    pub gpu_available: bool,
526}
527
528/// Inputs to [`should_use_gpu_pirls_loop`]. Each field comes from data the
529/// CPU PIRLS entry has on hand before it touches the eigendecomposition
530/// engine, so the admission check itself is allocation-free and can short-
531/// circuit before any heavy work happens.
532#[derive(Clone, Copy, Debug)]
533pub struct PirlsLoopAdmission {
534    /// Number of rows in the active (post-transform) design matrix.
535    pub n: usize,
536    /// Number of columns in the active design (i.e. `p` of `Xᵀ X`).
537    pub p: usize,
538    /// `Some(_)` when the inner family maps onto one of the six JIT-cached
539    /// `PirlsRowFamily` variants; `None` for custom families that still
540    /// require Stage 6 Level B and have not yet been admitted here.
541    pub family: Option<PirlsLoopFamilyKind>,
542    /// Curvature surface the inner loop will use; the GPU loop has Fisher +
543    /// Observed kernels, anything else (e.g. expected-projection surrogates)
544    /// is not admitted.
545    pub curvature: PirlsLoopCurvatureKind,
546    /// True when the CUDA runtime is initialised on this host (i.e.
547    /// `GpuRuntime::global().is_some()`).
548    pub gpu_available: bool,
549}
550
551impl GpuDispatchPolicy {
552    /// Minimum design column count for the device-resident inner/outer loops.
553    ///
554    /// Below this width the per-iteration `XᵀWX + Cholesky` is dominated by
555    /// launch latency and PCIe staging rather than arithmetic, so the host LM
556    /// loop (which populates the full `PirlsResult` surface as a free
557    /// side-effect) is strictly cheaper. Shared by both the inner PIRLS and
558    /// outer REML admission predicates so they cannot drift apart.
559    pub const DEVICE_LOOP_MIN_P: usize = 32;
560
561    /// Conservative admission predicate for routing
562    /// `fit_model_for_fixed_rho_with_adaptive_kkt` through the Stage 3.3
563    /// device-resident PIRLS loop instead of the CPU LM loop.
564    ///
565    /// The threshold is the dense `XᵀWX` work estimate, not row count alone:
566    /// LLM/SAE fits can have only a few thousand rows but thousands of columns,
567    /// so `2*n*p^2` already dwarfs launch/staging overhead. Smaller fits stay on
568    /// the CPU LM loop where the full `PirlsResult` surface (firth, EDF,
569    /// per-row weights, …) is already populated as a free side-effect of the
570    /// iteration.
571    pub const fn should_use_gpu_pirls_loop(&self, adm: PirlsLoopAdmission) -> bool {
572        if !adm.gpu_available {
573            return false;
574        }
575        if !self.dense_hessian_work_target_is_gpu(adm.n, adm.p) {
576            return false;
577        }
578        match adm.family {
579            Some(_) => true,
580            None => false,
581        }
582    }
583
584    /// Admission predicate for routing the outer REML BFGS-over-ρ loop onto
585    /// a device-resident driver that keeps the BFGS state (ρ, gradient,
586    /// Hessian approx) on-device and only downloads the per-step scalar
587    /// metrics (objective value, gradient norm, convergence flag).
588    ///
589    /// The dense-work threshold piggybacks on the existing inner-PIRLS admission
590    /// predicate because the device-resident outer loop calls
591    /// `pirls_loop_on_stream` per step and must not pay the host hop for small
592    /// fits the inner loop would have rejected anyway. The
593    /// `num_rho ≥ 2` floor rules out the trivial single-smoother case where
594    /// host orchestration is already negligible and the device BFGS state
595    /// (one length-`num_rho` gradient + a `num_rho × num_rho` Hessian
596    /// approx) collapses to a couple of scalars not worth keeping on device.
597    pub const fn should_run_reml_outer_on_device(&self, adm: RemlOuterAdmission) -> bool {
598        if !adm.gpu_available {
599            return false;
600        }
601        if !self.dense_hessian_work_target_is_gpu(adm.n, adm.p) {
602            return false;
603        }
604        if adm.num_rho < 2 {
605            return false;
606        }
607        match adm.family {
608            Some(_) => true,
609            None => false,
610        }
611    }
612}
613
#[cfg(test)]
mod refinement_policy_tests {
    use super::*;

    /// Default thresholds with the mixed-precision mode overridden.
    fn policy_with_mode(mode: GpuMixedPrecisionPolicy) -> GpuDispatchPolicy {
        GpuDispatchPolicy {
            mixed_precision: mode,
            ..Default::default()
        }
    }

    /// The default policy is `Refinement`, so any `p` at or above the floor
    /// is admitted.
    #[test]
    fn refinement_policy_admits_large_p() {
        let policy = GpuDispatchPolicy::default();
        for &p in &[512, GpuDispatchPolicy::REFINEMENT_MIN_P] {
            assert!(policy.iterative_refinement_should_attempt(p));
        }
    }

    /// Anything below the floor is rejected, including the empty problem.
    #[test]
    fn refinement_policy_rejects_small_p() {
        let policy = GpuDispatchPolicy::default();
        for &p in &[GpuDispatchPolicy::REFINEMENT_MIN_P - 1, 0] {
            assert!(!policy.iterative_refinement_should_attempt(p));
        }
    }

    /// `Off` disables the attempt regardless of size.
    #[test]
    fn off_policy_never_attempts_refinement() {
        let policy = policy_with_mode(GpuMixedPrecisionPolicy::Off);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }

    /// `Never` disables the attempt regardless of size.
    #[test]
    fn never_policy_never_attempts_refinement() {
        let policy = policy_with_mode(GpuMixedPrecisionPolicy::Never);
        assert!(!policy.iterative_refinement_should_attempt(1024));
    }
}
651
#[cfg(test)]
mod reduced_schur_matvec_offload_tests {
    //! Admission tests for `GpuDispatchPolicy::reduced_schur_matvec_should_offload`
    //! and the `admission_work_lower_bound` it is documented against. All
    //! tests run against `GpuDispatchPolicy::default()`.

    use super::*;

    /// The LLM/SAE shape that the #1017 Phase-1 re-keying exists for: a few
    /// thousand row blocks, a *wide* border (decoder atom count in the
    /// thousands), a modest per-row frame depth, and a realistic CG budget.
    /// Neither the row-count gate (50k) nor the dense-Direct flop floor
    /// catches this "thousands of tiny dense ops" shape, so the work-amortised
    /// matvec gate has to admit it.
    #[test]
    fn admits_llm_sae_matvec_shape() {
        let policy = GpuDispatchPolicy::default();

        // n≈2000 rows, k≈2048 atoms, M≈8 frame depth — n sits far below the
        // 50k row gate, yet the summed CG matvec work is large.
        let matvec_offloads = policy.reduced_schur_matvec_should_offload(
            2_000,
            2_048,
            8,
            GpuDispatchPolicy::MATVEC_OFFLOAD_MIN_CG_ITERS,
        );
        assert!(matvec_offloads);

        // The row-count-style dense gate rejects the same shape, which
        // confirms that the re-keying is what admits it.
        let dense_gate_fires = policy.dense_hessian_work_target_is_gpu(2_000, 8);
        assert!(!dense_gate_fires);
    }

    /// A single conservative CG iteration is already enough for the wide LLM
    /// border to clear the breakeven: the per-apply work alone is
    /// `2_000·(2·8·2_048 + 8²) ≈ 6.6e7` flops > 1e7 under the conservative
    /// `n·(2·d·k + d²)` model (the true `n·(4·d·k + d²)` arithmetic is
    /// ≈1.3e8), so the gate does not depend on an inflated iteration count.
    #[test]
    fn admits_llm_shape_with_one_cg_iter() {
        let policy = GpuDispatchPolicy::default();
        let offloads_with_one_iter = policy.reduced_schur_matvec_should_offload(2_000, 2_048, 8, 1);
        assert!(offloads_with_one_iter);
    }

    /// Tiny shapes where the host↔device transfer dominates must stay on the
    /// CPU: a handful of rows, a narrow border, shallow frames. Their summed
    /// matvec work is orders of magnitude below the staging breakeven.
    #[test]
    fn rejects_tiny_shape_where_transfer_dominates() {
        let policy = GpuDispatchPolicy::default();

        let tiny_offloads = policy.reduced_schur_matvec_should_offload(
            30,
            8,
            2,
            GpuDispatchPolicy::MATVEC_OFFLOAD_MIN_CG_ITERS,
        );
        assert!(!tiny_offloads);

        // The 300×8 shape the production seam tests use as the "stay CPU"
        // canary must be rejected here as well.
        let canary_offloads = policy.reduced_schur_matvec_should_offload(300, 8, 4, 16);
        assert!(!canary_offloads);
    }

    /// A border narrower than the device-loop floor is rejected no matter how
    /// much row/iteration work is piled on: per-apply launch latency dominates
    /// a sub-`DEVICE_LOOP_MIN_P` border.
    #[test]
    fn rejects_narrow_border_even_with_huge_row_count() {
        let policy = GpuDispatchPolicy::default();
        let sub_floor_border = GpuDispatchPolicy::DEVICE_LOOP_MIN_P - 1;
        let offloads =
            policy.reduced_schur_matvec_should_offload(1_000_000, sub_floor_border, 64, 64);
        assert!(!offloads);
    }

    /// Degenerate dimensions are never offloaded (no work, or no solve). Each
    /// row of the table zeroes exactly one of the four arguments of an
    /// otherwise LLM-sized shape.
    #[test]
    fn rejects_degenerate_dimensions() {
        let policy = GpuDispatchPolicy::default();
        let degenerate_shapes = [
            (0, 2_048, 8, 8),
            (2_000, 0, 8, 8),
            (2_000, 2_048, 0, 8),
            (2_000, 2_048, 8, 0),
        ];
        for (n, k, d, iters) in degenerate_shapes {
            // The message names the shape so a failure identifies which zeroed
            // dimension slipped through the gate.
            assert!(
                !policy.reduced_schur_matvec_should_offload(n, k, d, iters),
                "degenerate shape ({n},{k},{d},{iters}) must not be offloaded"
            );
        }
    }

    /// The gate is monotone in the CG budget: once a shape is admitted at a
    /// given iteration count it stays admitted for any larger count (more
    /// applies over the same resident frames only improves amortization), and
    /// a borderline shape crosses the breakeven as iterations grow.
    ///
    /// Besides the three spot checks, every budget in `1..=5_000` is swept so
    /// that an admitted-then-rejected regression between the sampled points
    /// cannot go unnoticed.
    #[test]
    fn monotone_in_cg_iters() {
        let policy = GpuDispatchPolicy::default();
        // A border at the floor with shallow frames and few rows: per-apply
        // work ~ n·(2·d·k + d²). The shape is chosen to sit below breakeven at
        // 1 iter but above it once enough iterations accumulate.
        let (n, k, d) = (200usize, GpuDispatchPolicy::DEVICE_LOOP_MIN_P, 4usize);

        // per_apply ≈ 200·(2·4·32 + 16) = 200·272 = 54_400 flops (this
        // arithmetic assumes DEVICE_LOOP_MIN_P is 32).
        assert!(!policy.reduced_schur_matvec_should_offload(n, k, d, 1));
        // Once the summed work clears 1e7 the gate fires; ~184 iters here.
        assert!(policy.reduced_schur_matvec_should_offload(n, k, d, 1_000));
        // Monotonicity: admitted at 1_000 ⇒ admitted at every larger budget.
        assert!(policy.reduced_schur_matvec_should_offload(n, k, d, 5_000));

        // Full sweep: the decision may flip from rejected to admitted, but
        // never from admitted back to rejected.
        let mut previously_admitted = false;
        for iters in 1..=5_000 {
            let admitted = policy.reduced_schur_matvec_should_offload(n, k, d, iters);
            assert!(
                admitted || !previously_admitted,
                "gate rejected ({n},{k},{d}) at {iters} CG iterations after admitting a smaller budget"
            );
            previously_admitted = admitted;
        }
    }

    /// The admission lower bound must stay strictly below the true per-apply
    /// work `n·(4·d·k + d²)` for any non-degenerate cross-block shape (it
    /// drops the transpose GEMV). Treating the lower bound as a flop count
    /// would over-report device speedups, so this asserts the gap is real.
    #[test]
    fn admission_lower_bound_undercounts_actual_work() {
        let shapes = [
            (2_000usize, 2_048usize, 8usize),
            (200, GpuDispatchPolicy::DEVICE_LOOP_MIN_P, 4),
            (1, 1, 1),
        ];
        for (n, k, d) in shapes {
            let lower = GpuDispatchPolicy::admission_work_lower_bound(n, k, d);
            // True per-apply work models the full forward+transpose GEMV pair
            // plus the d×d solve: n·(4·d·k + d²). Widened to u128 so the
            // reference value itself cannot overflow.
            let (n_wide, k_wide, d_wide) = (n as u128, k as u128, d as u128);
            let actual = n_wide * (4 * d_wide * k_wide + d_wide * d_wide);
            assert!(
                lower < actual,
                "admission lower bound {lower} must undercount actual work {actual} for ({n},{k},{d})"
            );
        }
    }
}
767
#[cfg(test)]
mod encode_deployment_decision_tests {
    //! Tests for `EncodeDeploymentDecision`: only an engaged device measurement
    //! with a usable rate may yield `Met`/`Unmet`; everything else must be
    //! undetermined and must report neither "surrogate unneeded" nor
    //! "surrogate justified".

    use super::*;

    /// #1412 anti-green-wash core: a CPU rate can NEVER produce a `Met`/`Unmet`
    /// decision. The only Met/Unmet constructor requires `engaged == true`; a
    /// CPU-only host has no device measurement, so it can only ever be
    /// `Undetermined`, no matter how fast the CPU is.
    #[test]
    fn cpu_rate_can_never_meet_or_refute_the_target() {
        // Even a CPU rate a thousand times the target cannot certify the gate:
        // there is simply no `from_cpu_measurement` — the type has no such
        // door. The blocked constructor is the only CPU-side option.
        let host_only = EncodeDeploymentDecision::blocked(EncodeDecisionBlocked::NoDevice);
        assert!(host_only.is_undetermined());
        assert!(!host_only.surrogate_unneeded());
        assert!(!host_only.surrogate_justified());

        // A "device" measurement that did not engage (false routing) is
        // refused — it becomes Undetermined even with a huge rate.
        let not_engaged = EncodeDeploymentDecision::from_device_measurement(false, 1.0e9);
        assert!(not_engaged.is_undetermined());
        assert!(!not_engaged.surrogate_unneeded());
        // It must not refute the target either: an unengaged measurement can
        // neither meet nor justify anything.
        assert!(!not_engaged.surrogate_justified());
    }

    /// An engaged measurement with a usable rate is decided purely by
    /// comparing the rate with `GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`; a rate
    /// exactly on the target counts as meeting it.
    #[test]
    fn engaged_measurement_decides_by_the_number() {
        let target = GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;

        // Clears the target => Met => surrogate unneeded.
        let above_target = EncodeDeploymentDecision::from_device_measurement(true, target * 2.0);
        assert!(matches!(above_target, EncodeDeploymentDecision::Met { .. }));
        assert!(above_target.surrogate_unneeded());
        assert!(!above_target.surrogate_justified());
        assert!(!above_target.is_undetermined());

        // Misses the target => Unmet => surrogate justified.
        let below_target = EncodeDeploymentDecision::from_device_measurement(true, target * 0.25);
        assert!(matches!(below_target, EncodeDeploymentDecision::Unmet { .. }));
        assert!(below_target.surrogate_justified());
        assert!(!below_target.surrogate_unneeded());

        // Exact boundary meets the target.
        let on_target = EncodeDeploymentDecision::from_device_measurement(true, target);
        assert!(on_target.surrogate_unneeded());
    }

    /// An engaged measurement whose rate is zero, negative, NaN, or infinite
    /// carries no usable number, so it must be undetermined rather than being
    /// read as either outcome.
    #[test]
    fn engaged_but_non_usable_rate_is_undetermined_not_a_pass() {
        let unusable_rates = [0.0, -1.0, f64::NAN, f64::INFINITY];
        for bad in unusable_rates {
            let decision = EncodeDeploymentDecision::from_device_measurement(true, bad);
            assert!(
                decision.is_undetermined(),
                "an engaged-but-unusable rate {bad} must be Undetermined, not a decision"
            );
            assert!(!decision.surrogate_unneeded());
            assert!(!decision.surrogate_justified());
        }
    }

    /// Every blocked reason listed here yields an undetermined decision that
    /// neither clears nor justifies the surrogate.
    #[test]
    fn blocked_reasons_are_all_undetermined() {
        let reasons = [
            EncodeDecisionBlocked::NoDevice,
            EncodeDecisionBlocked::NoDeviceEncodeKernel,
            EncodeDecisionBlocked::DeviceNotEngaged,
        ];
        for reason in reasons {
            let decision = EncodeDeploymentDecision::blocked(reason);
            assert!(decision.is_undetermined());
            assert!(!decision.surrogate_unneeded());
            assert!(!decision.surrogate_justified());
        }
    }
}