Skip to main content

gam_gpu/
encode_throughput.rs

//! Measured device-resident encode throughput for the SAE/LLM batched-solve
//! shape (#1412, #988, #1017 Phase 3).
//!
//! ## Why this module exists
//!
//! The historical throughput "decision gate" (#1412) asserted a `100_000`
//! rows/sec/GPU deployment target **without ever measuring a device**. Its
//! successor still keyed the deployment decision on a *CPU* measurement scaled
//! by a hardcoded `CPU_TO_GPU_SCALING = 100.0` fudge factor, so passing the
//! gate established nothing about real GPU throughput. #988 was closed as
//! `COMPLETED` even though the maintainer's own follow-up confirmed that the
//! GPU steady-state encode rate had never been measured.
//!
//! This module makes the measurement real and *testable as a library function*
//! (the previous real benchmark lived only in `examples/throughput_1412.rs`,
//! which nothing in CI ran or asserted). [`measure_resident_solve_throughput`]
//! runs the production IRLS inner step on the real device — upload `X` once,
//! then repeatedly solve the penalized normal equations
//! `(XᵀWX + ridge·I)β = rhs` with the `p×p` Gram and its Cholesky factor kept
//! device-resident, downloading only the `p`-vector `β` — and reports the
//! measured design rows per second.
//!
//! ## Fail-loud, never false-route
//!
//! The single recurring failure mode this guards against is *false GPU
//! routing*: claiming a device measurement while the work silently ran on the
//! CPU. [`ResidentSolveThroughput::engaged`] is `true` only when
//! [`ResidentDesignGram::try_new`] actually staged `X` on the device AND every
//! timed solve returned a device result. If the device path declines or fails
//! (at staging, at the warm-up solve, or mid-measurement), `engaged` is
//! `false` and `measured_rows_per_sec` is left at `0.0`: a non-measurement
//! that [`GpuThroughputVerdict`] can never report as meeting the target. There
//! is no CPU fallback inside the measurement; a caller that wants the CPU
//! oracle runs it separately for parity.
33
34use std::hint::black_box;
35use std::time::{Duration, Instant};
36
37use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
38
39use super::linalg_dispatch::ResidentDesignGram;
40use super::policy::{GpuThroughputVerdict, GPU_THROUGHPUT_TARGET_ROWS_PER_SEC};
41
/// One LLM/SAE batched-solve work cell to measure.
///
/// A cell is described by `n` design rows and a `p`-wide decoder border. The
/// per-atom reduced-Schur block size `d` is fixed by the term and plays no part
/// in the resident-solve throughput, so it is not recorded here.
#[derive(Debug, Clone, Copy)]
pub struct EncodeShape {
    /// Label used when reporting this cell.
    pub label: &'static str,
    /// Number of design rows sent through the device for each fit.
    pub n: usize,
    /// Width of the decoder border; the resident Gram is `p×p`.
    pub p: usize,
}
54
55/// The canonical qwen/olmo-scale SAE residual-block shapes (matches the
56/// `examples/throughput_1412.rs` workload so the library measurement and the
57/// example agree).
58pub const CANONICAL_ENCODE_SHAPES: &[EncodeShape] = &[
59    EncodeShape {
60        label: "sae-2k-2048",
61        n: 2_000,
62        p: 2_048,
63    },
64    EncodeShape {
65        label: "sae-4k-4096",
66        n: 4_000,
67        p: 4_096,
68    },
69    EncodeShape {
70        label: "sae-8k-1024",
71        n: 8_000,
72        p: 1_024,
73    },
74];
75
76/// Outcome of measuring the device-resident penalized-solve throughput for one
77/// [`EncodeShape`].
78#[derive(Clone, Copy, Debug)]
79pub struct ResidentSolveThroughput {
80    /// The shape that was measured.
81    pub shape: EncodeShape,
82    /// `true` iff `X` was staged on the device AND every timed solve returned a
83    /// device result. `false` means the device path declined or failed — the
84    /// number below is **not** a device measurement.
85    pub engaged: bool,
86    /// Measured design-rows/sec for the resident solve, or `0.0` when the
87    /// device path did not engage (a non-measurement).
88    pub measured_rows_per_sec: f64,
89    /// The verdict comparing `measured_rows_per_sec` against
90    /// [`GPU_THROUGHPUT_TARGET_ROWS_PER_SEC`].
91    pub verdict: GpuThroughputVerdict,
92}
93
/// Reproducible pseudo-random draw in `[-1, 1)` from a 64-bit LCG.
///
/// This avoids a `rand` dependency. Advancing `state` in place keeps the
/// measured fixture identical from run to run.
fn lcg(state: &mut u64) -> f64 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;
    // 2^53: dividing the top 53 bits of the state by this gives a value in [0, 1).
    const SCALE: f64 = (1u64 << 53) as f64;

    let next = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = next;
    let unit = (next >> 11) as f64 / SCALE;
    unit * 2.0 - 1.0
}
102
103/// Build a deterministic `n×p` design fixture for the throughput measurement.
104fn planted_design(n: usize, p: usize, seed: u64) -> Array2<f64> {
105    let mut s = seed;
106    Array2::from_shape_fn((n, p), |_| lcg(&mut s) * 0.05)
107}
108
109/// Measure the device-resident penalized-normal-equations solve throughput for
110/// one shape: upload `X` once, then time `reps` solves that cross only `w`
111/// (H2D), `rhs` (H2D, fixed), and `β` (D2H) — the production IRLS inner step.
112///
113/// `reps` is the number of timed solves; `w` is perturbed per rep so each solve
114/// is genuine work, mirroring an IRLS weight update. Returns a
115/// [`ResidentSolveThroughput`] whose `engaged` flag is the false-routing guard:
116/// on a CPU-only host (or if the device declines) it is `false` and the rate is
117/// `0.0`.
118#[must_use]
119pub fn measure_resident_solve_throughput(shape: EncodeShape, reps: usize) -> ResidentSolveThroughput {
120    let EncodeShape { n, p, .. } = shape;
121    let not_engaged = |shape| ResidentSolveThroughput {
122        shape,
123        engaged: false,
124        measured_rows_per_sec: 0.0,
125        verdict: GpuThroughputVerdict::from_measurement(0.0),
126    };
127    if n == 0 || p == 0 || reps == 0 {
128        return not_engaged(shape);
129    }
130
131    let x = planted_design(n, p, 0x1412_a100_dead_beef);
132    let w = {
133        let mut s = 0x988_5ae_e0c0_de01u64;
134        Array1::from_shape_fn(n, |_| lcg(&mut s).abs() + 1e-3)
135    };
136    let rhs = Array1::from_shape_fn(p, |j| ((j as f64 + 1.0) * 0.03).cos());
137    let ridge = 1e-3_f64;
138
139    // Stage X once. `None` => no device / shape below the Gram threshold => not
140    // a device measurement.
141    let handle = match ResidentDesignGram::try_new(x.view()) {
142        Some(h) => h,
143        None => return not_engaged(shape),
144    };
145
146    // Warm the resident solve (allocations, kernel handles) outside the timer;
147    // if even the warm solve declines, the device path is not usable here.
148    if handle.solve_normal_equations(w.view(), rhs.view(), ridge).is_none() {
149        return not_engaged(shape);
150    }
151
152    let mut total = Duration::ZERO;
153    for r in 0..reps {
154        let wr = Array1::from_shape_fn(n, |i| (w[i] + 1e-3 * (r as f64)).abs());
155        let start = Instant::now();
156        match handle.solve_normal_equations(wr.view(), rhs.view(), ridge) {
157            Some(beta) => {
158                black_box(beta);
159            }
160            // A mid-measurement decline means the timed region is no longer a
161            // pure device measurement — refuse to report it as one.
162            None => return not_engaged(shape),
163        }
164        total += start.elapsed();
165    }
166
167    let secs = total.as_secs_f64() / reps as f64;
168    let measured_rows_per_sec = if secs > 0.0 { n as f64 / secs } else { 0.0 };
169    ResidentSolveThroughput {
170        shape,
171        engaged: measured_rows_per_sec > 0.0,
172        measured_rows_per_sec,
173        verdict: GpuThroughputVerdict::from_measurement(measured_rows_per_sec),
174    }
175}
176
177/// CPU oracle for the same penalized normal-equations solve, used for parity:
178/// `(XᵀWX + ridge·I)β = rhs` solved by a host Cholesky. This is the definition
179/// of truth the device solve must match (up to IEEE-754 reduction order).
180#[must_use]
181pub fn cpu_oracle_normal_equations_solve(
182    x: ArrayView2<'_, f64>,
183    w: ArrayView1<'_, f64>,
184    rhs: ArrayView1<'_, f64>,
185    ridge: f64,
186) -> Array1<f64> {
187    let (n, p) = x.dim();
188    assert_eq!(w.len(), n, "w must have one entry per design row");
189    assert_eq!(rhs.len(), p, "rhs must have one entry per border column");
190
191    // Gram = Xᵀ diag(w) X + ridge·I, formed in f64 as (√w⊙X)ᵀ(√w⊙X) via the
192    // BLAS-backed `dot` (the scalar triple loop is O(n·p²) and dominates the
193    // oracle at p in the thousands). Folding √w into both factors keeps the
194    // weighting exact: row i contributes wᵢ·xᵢₐ·xᵢᵦ as (√wᵢxᵢₐ)(√wᵢxᵢᵦ).
195    let mut xw = x.to_owned();
196    for i in 0..n {
197        let sw = w[i].sqrt();
198        for a in 0..p {
199            xw[[i, a]] *= sw;
200        }
201    }
202    let mut gram = xw.t().dot(&xw);
203    for j in 0..p {
204        gram[[j, j]] += ridge;
205    }
206
207    // Cholesky: gram = L Lᵀ (lower), then solve L y = rhs, Lᵀ β = y.
208    let mut l = Array2::<f64>::zeros((p, p));
209    for j in 0..p {
210        let mut diag = gram[[j, j]];
211        for s in 0..j {
212            diag -= l[[j, s]] * l[[j, s]];
213        }
214        let ljj = diag.max(0.0).sqrt();
215        l[[j, j]] = ljj;
216        for i in (j + 1)..p {
217            let mut off = gram[[i, j]];
218            for s in 0..j {
219                off -= l[[i, s]] * l[[j, s]];
220            }
221            l[[i, j]] = off / ljj;
222        }
223    }
224    let mut y = rhs.to_owned();
225    for i in 0..p {
226        let mut acc = y[i];
227        for s in 0..i {
228            acc -= l[[i, s]] * y[s];
229        }
230        y[i] = acc / l[[i, i]];
231    }
232    let mut beta = y;
233    for i in (0..p).rev() {
234        let mut acc = beta[i];
235        for s in (i + 1)..p {
236            acc -= l[[s, i]] * beta[s];
237        }
238        beta[i] = acc / l[[i, i]];
239    }
240    beta
241}
242
243/// The deployment target, re-exported so callers measuring throughput do not
244/// have to import the policy module directly.
245pub const DEPLOYMENT_TARGET_ROWS_PER_SEC: f64 = GPU_THROUGHPUT_TARGET_ROWS_PER_SEC;