Skip to main content

gam_sae/sparse_dict/
scoring_gpu.rs

//! GPU score-block kernel for the collapsed-linear-lane router (#1026).
//!
//! The collapsed linear lane ([`crate::sparse_dict`]) scales a linear SAE
//! dictionary to `K ≈ 32_000` atoms by routing each row against the WHOLE
//! dictionary one atom-tile at a time and keeping only the top-`s` atoms online.
//! That route step — `scores[r][a] = Σ_c x[r][c]·decoder[a][c]` over a
//! `rows × tile × P` block — dominates the cost of a fit (at `K ≈ 32k` it is
//! where the measured 1e4–1e6× hardware gap tracked by the issue lives), and it
//! has exactly the embarrassingly-parallel shape a GPU is built for.
//!
//! # What this offloads (and what it does NOT)
//!
//! This computes ONE atom-tile's `rows × tile` score block on the device —
//! exactly the block the CPU [`super::scoring::score_row_tile`] folds into the
//! per-row online top-`s` selectors. The selection logic itself
//! ([`super::scoring::TopSSelector`]) stays single-sourced on the CPU: the
//! device returns the score block, the host folds it into the SAME selectors,
//! and then discards it. The minibatch router [`route_minibatch_required`]
//! walks the whole `K`-wide dictionary in atom-column tiles (each launch's block
//! capped at `GPU_ROUTE_TILE_ELEMS`), so peak host/device score memory is
//! `rows × tile`, **independent of `K`**: the lane's no-`N×K` memory discipline
//! holds on the device exactly as on the CPU, and the GPU only takes over the
//! `O(rows·tile·P)` multiply-accumulate work that dominates.
//!
//! # Bit-exact parity (the gate, not a tolerance)
//!
//! The CPU oracle accumulates `acc += x[c]·d[c]` as a SEPARATE f32 multiply
//! followed by an f32 add, in ascending `c` order: rustc never contracts
//! `a*b+c` into a fused multiply-add (only an explicit `f32::mul_add` is
//! fused). NVRTC, by contrast, defaults to `--fmad=true`, which contracts
//! `a*b+c` into a single-rounding FMA — a ~1 ULP difference that can flip a
//! near-tie top-`s` selection and make the routed support, and hence the whole
//! fit, diverge from the CPU oracle.
//!
//! So the kernel forces SEPARATE rounding with `__fmul_rn` + `__fadd_rn` (which
//! the CUDA compiler never merges into an FMA), in the SAME ascending-`c` order,
//! giving a score block that is **bit-for-bit** identical to the CPU
//! `score_row_tile` (every `f32` equal under `to_bits`). This also relies on
//! NVRTC's default `--ftz=false`, which keeps subnormals as the CPU does;
//! compiling with `--ftz=true` or `--use_fast_math` could break parity.
//! Because the scores are identical, the CPU selector fed device scores produces
//! the IDENTICAL routed support — parity is exact by construction, not bounded
//! by a tolerance.

41#![cfg(target_os = "linux")]
42
43use ndarray::ArrayView2;
44
/// CUDA C source of the bit-exact-parity score kernel, minus its shape macro.
///
/// `sparse_dict_score_block` runs one thread per `(row, atom)` output of a
/// row-major `n_rows × n_atoms` block. Threads whose flat index is at or past
/// `n_rows * n_atoms` return immediately. The flat index, the block extent and
/// the row/atom base offsets are all computed in `long long`, so the product of
/// the two `int` extents cannot overflow 32-bit arithmetic. Each live thread
/// walks the `PP` columns in ascending order with separate-rounding f32 ops
/// (`__fmul_rn`, then `__fadd_rn`), so the result matches the CPU sequential
/// `acc += x·d` to the bit.
///
/// `PP` (the column count) is NOT defined in this constant.
/// [`score_block_kernel_source`] prepends it as a `#define`, which makes the
/// inner loop a fixed trip count. This matches the other NVRTC kernels in this
/// repo, which monomorphise their shape macros for a pure `compile_ptx`.
/// Compiling this string on its own leaves `PP` undefined.
pub const SCORE_BLOCK_KERNEL_SOURCE: &str = r#"
extern "C" __global__
void sparse_dict_score_block(
    const float* __restrict__ rows,    // [n_rows * PP] row-major
    const float* __restrict__ atoms,   // [n_atoms * PP] row-major (decoder tile)
    int n_rows,
    int n_atoms,
    float* __restrict__ scores)        // [n_rows * n_atoms] row-major
{
  long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
  long long total = (long long)n_rows * (long long)n_atoms;
  if (idx >= total) return;
  int r = (int)(idx / n_atoms);
  int a = (int)(idx % n_atoms);
  const float* xr = rows  + (long long)r * PP;
  const float* da = atoms + (long long)a * PP;
  // SEPARATE-rounding accumulation in ascending c — NO fused multiply-add, so
  // this is bit-identical to the CPU `acc += x[c]*d[c]` reference order.
  float acc = 0.0f;
  for (int c = 0; c < PP; ++c) {
    float prod = __fmul_rn(xr[c], da[c]);
    acc = __fadd_rn(acc, prod);
  }
  scores[idx] = acc;
}
"#;
78
79/// Prepend the `PP` shape macro so the NVRTC compile is a pure `compile_ptx`
80/// (mirrors `sae_rowjet::softmax_kernel_source` / `arrow_schur_nvrtc`).
81#[must_use]
82pub fn score_block_kernel_source(p: usize) -> String {
83    format!("#define PP {p}\n{SCORE_BLOCK_KERNEL_SOURCE}")
84}
85
86/// CPU reference for the score block: `scores[r*n_atoms + a] = Σ_c
87/// rows[r][c]·atoms[a][c]`, accumulated in ascending `c` with separate f32
88/// rounding — the SAME arithmetic [`super::scoring::score_row_tile`] runs
89/// per atom. This is the parity oracle the device kernel is locked against.
90#[must_use]
91pub fn score_block_cpu(rows: ArrayView2<'_, f32>, atoms: ArrayView2<'_, f32>) -> Vec<f32> {
92    let n_rows = rows.nrows();
93    let n_atoms = atoms.nrows();
94    let p = rows.ncols();
95    assert_eq!(p, atoms.ncols(), "score_block_cpu: P mismatch rows vs atoms");
96    let mut scores = vec![0.0f32; n_rows * n_atoms];
97    for r in 0..n_rows {
98        let xr = rows.row(r);
99        for a in 0..n_atoms {
100            let da = atoms.row(a);
101            let mut acc = 0.0f32;
102            for c in 0..p {
103                // separate mul then add — matches the kernel's __fmul_rn/__fadd_rn
104                acc += xr[c] * da[c];
105            }
106            scores[r * n_atoms + a] = acc;
107        }
108    }
109    scores
110}
111
/// Device-admission floor for score work, counted in output elements
/// (`n_rows · n_atoms`).
///
/// Each element is one `P`-long dot product, so a block at the floor costs
/// `n_rows · n_atoms · P` multiply-adds, not `n_rows · n_atoms`. Below this
/// floor the fixed device cost (backend probe, H2D upload, D2H download) is
/// not worth paying, and the CPU reference runs instead.
///
/// [`score_block_required`] applies the floor per block.
/// [`route_minibatch_required`] applies it to the whole `m × K` routing
/// workload. Under `GpuMode::Required`, work below the floor is an error
/// rather than a silent CPU run.
///
/// The value is chosen to be the same order of magnitude as the other SAE
/// device floors (`sae_rowjet::DEVICE_ROW_THRESHOLD`).
pub const DEVICE_SCORE_BLOCK_MIN_ELEMS: usize = 1 << 20; // 1_048_576 score elements (×P MACs)
117
/// Reports which engine actually produced a score block or a routing result.
///
/// The fail-loud entry points return this next to their output. Callers and
/// parity tests can then assert that the device really ran, instead of
/// silently accepting a CPU fallback (the #1026/#1551 'GPU 0%' failure mode).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScoreBlockPath {
    /// Computed by the NVRTC `sparse_dict_score_block` kernel on the GPU.
    Device,
    /// Computed on the host by the CPU reference (`score_block_cpu`, or the
    /// per-row `top_s_online` oracle when routing).
    Cpu,
}
128
129/// Fail-loud, residency-aware score-block entry point (#1026 scale-K lane).
130///
131/// Honours the process-wide [`gam_gpu::GpuMode`] contract: under
132/// [`gam_gpu::GpuMode::Required`] a missing CUDA runtime, an NVRTC/arch compile
133/// failure, a launch fault, or a block below the device break-even all return
134/// `Err` instead of silently degrading to the CPU. [`gam_gpu::GpuMode::Auto`]
135/// uses the device when admitted and the block clears the break-even, else the
136/// CPU; [`gam_gpu::GpuMode::Off`] always the CPU. The returned [`ScoreBlockPath`]
137/// reports which path actually ran.
138///
139/// Both paths produce a BIT-IDENTICAL `f32` score block (see module docs), so
140/// the routed top-`s` support is identical whichever path runs.
141///
142/// # Errors
143/// Returns [`gam_gpu::GpuError`] when [`gam_gpu::GpuMode::Required`] is set but
144/// the device path cannot run.
145pub fn score_block_required(
146    rows: ArrayView2<'_, f32>,
147    atoms: ArrayView2<'_, f32>,
148    mode: gam_gpu::GpuMode,
149) -> Result<(Vec<f32>, ScoreBlockPath), gam_gpu::GpuError> {
150    use gam_gpu::GpuMode;
151
152    let n_rows = rows.nrows();
153    let n_atoms = atoms.nrows();
154    let elems = n_rows.saturating_mul(n_atoms);
155
156    if mode == GpuMode::Off {
157        return Ok((score_block_cpu(rows, atoms), ScoreBlockPath::Cpu));
158    }
159
160    let below_breakeven = elems < DEVICE_SCORE_BLOCK_MIN_ELEMS;
161    if mode == GpuMode::Required && below_breakeven {
162        return Err(gam_gpu::gpu_err!(
163            "sparse_dict score-block GpuMode::Required: block of {n_rows}×{n_atoms} \
164             = {elems} elems is below the device launch break-even \
165             (DEVICE_SCORE_BLOCK_MIN_ELEMS={DEVICE_SCORE_BLOCK_MIN_ELEMS}); refusing \
166             to silently run on the CPU"
167        ));
168    }
169    if !below_breakeven {
170        match device::score_block_device(rows, atoms) {
171            Ok(out) => return Ok((out, ScoreBlockPath::Device)),
172            Err(err) => {
173                if mode == GpuMode::Required {
174                    return Err(err);
175                }
176                // Auto: fall through to the CPU.
177            }
178        }
179    }
180
181    Ok((score_block_cpu(rows, atoms), ScoreBlockPath::Cpu))
182}
183
/// Upper bound on score elements per device launch in the tiled GPU router.
///
/// [`route_minibatch_required`] never forms the full `m × K` score block.
/// Instead it:
/// 1. sweeps `K` in atom-column tiles whose `m × cols` block stays within this
///    cap (2 Mi f32 ≈ 8 MiB on the host plus 8 MiB on the device; a tile is
///    always at least one column wide),
/// 2. folds each tile into the selectors,
/// 3. drops the tile.
///
/// Peak score memory is therefore bounded **independently of `K`**. This
/// mirrors the `rows × tile` column tiling of the CPU lane
/// ([`super::scoring::top_s_online`]), so a `K ≈ 32_000` fit does not grow a
/// device allocation linearly in `K`.
const GPU_ROUTE_TILE_ELEMS: usize = 2 * 1024 * 1024;
192
193/// Route a whole minibatch of rows against the full decoder, returning each
194/// row's top-`s` `(atom, score)` selection — BIT-IDENTICAL to calling
195/// [`super::scoring::top_s_online`] per row, but with the score block computed on
196/// the device (in `K`-tiled launches) when admitted.
197///
198/// The selection ([`super::scoring::TopSSelector`]) is single-sourced on the CPU
199/// and fed the device scores in ascending atom order. `TopSSelector` keeps the
200/// top-`s` by `(|score| desc, atom asc)` — a strict total order on the unique
201/// atom indices — so the selected set is independent of the order or the tiling
202/// in which candidates are offered. Combined with bit-identical scores (the
203/// kernel forbids FMA contraction), the routed support matches the CPU oracle
204/// **exactly**, whichever path and whatever GPU tile width computed the block.
205///
206/// Memory: the `m × K` block is never formed whole — `K` is walked in tiles of
207/// at most `GPU_ROUTE_TILE_ELEMS / m` atom-columns, each launched, folded, and
208/// discarded, so peak score memory is `m × tile_cols`, independent of `K`.
209///
210/// Falls back to the per-row CPU `top_s_online` under [`gam_gpu::GpuMode::Off`],
211/// below the device break-even, or on any device error under
212/// [`gam_gpu::GpuMode::Auto`]; under [`gam_gpu::GpuMode::Required`] a device
213/// failure is propagated. The returned [`ScoreBlockPath`] reports which path ran.
214///
215/// # Errors
216/// Returns [`gam_gpu::GpuError`] when [`gam_gpu::GpuMode::Required`] is set but
217/// the device path cannot run for this minibatch.
218pub fn route_minibatch_required(
219    rows: ArrayView2<'_, f32>,
220    decoder: ArrayView2<'_, f32>,
221    s: usize,
222    tile: usize,
223    mode: gam_gpu::GpuMode,
224) -> Result<(Vec<Vec<(u32, f32)>>, ScoreBlockPath), gam_gpu::GpuError> {
225    use super::scoring::{TopSSelector, top_s_online};
226
227    let m = rows.nrows();
228    let k = decoder.nrows();
229
230    // CPU per-row path (bit-identical oracle), used for Off / below break-even /
231    // Auto device-error fallback.
232    let cpu_route = || -> Vec<Vec<(u32, f32)>> {
233        rows.outer_iter()
234            .map(|row| top_s_online(row, decoder, s, tile))
235            .collect()
236    };
237
238    if mode == gam_gpu::GpuMode::Off {
239        return Ok((cpu_route(), ScoreBlockPath::Cpu));
240    }
241
242    // Engagement is decided on the TOTAL work `m × K` (that is what justifies the
243    // device's fixed launch cost), but the launches themselves are K-tiled so the
244    // buffers never grow with K.
245    let elems = m.saturating_mul(k);
246    let below_breakeven = elems < DEVICE_SCORE_BLOCK_MIN_ELEMS;
247    if below_breakeven {
248        if mode == gam_gpu::GpuMode::Required {
249            return Err(gam_gpu::gpu_err!(
250                "route_minibatch GpuMode::Required: block of {m}×{k} = {elems} elems is below \
251                 the device launch break-even (DEVICE_SCORE_BLOCK_MIN_ELEMS={DEVICE_SCORE_BLOCK_MIN_ELEMS}); \
252                 refusing to silently run on the CPU"
253            ));
254        }
255        return Ok((cpu_route(), ScoreBlockPath::Cpu));
256    }
257    if m == 0 || k == 0 {
258        return Ok((cpu_route(), ScoreBlockPath::Cpu));
259    }
260
261    // Atom-columns per device launch: bound the per-launch block to
262    // GPU_ROUTE_TILE_ELEMS, at least one column, never more than K.
263    let tile_cols = (GPU_ROUTE_TILE_ELEMS / m).clamp(1, k);
264
265    // Per-row online selectors; each device tile's scores are folded in ascending
266    // global atom order (offset + ascending local), and the selector's result is
267    // tile-order-invariant, so the support is bit-identical to top_s_online.
268    let mut selectors: Vec<TopSSelector> = (0..m).map(|_| TopSSelector::new(s)).collect();
269    let mut start = 0usize;
270    while start < k {
271        let end = (start + tile_cols).min(k);
272        let atoms_tile = decoder.slice(ndarray::s![start..end, ..]);
273        match device::score_block_device(rows, atoms_tile) {
274            Ok(block) => {
275                let cols = end - start;
276                for (r, sel) in selectors.iter_mut().enumerate() {
277                    let base = r * cols;
278                    for (local, score) in block[base..base + cols].iter().enumerate() {
279                        sel.offer((start + local) as u32, *score);
280                    }
281                }
282            }
283            Err(err) => {
284                if mode == gam_gpu::GpuMode::Required {
285                    return Err(err);
286                }
287                // Auto: the device faulted mid-route; discard partial selectors and
288                // run the exact CPU oracle for the whole minibatch.
289                return Ok((cpu_route(), ScoreBlockPath::Cpu));
290            }
291        }
292        start = end;
293    }
294
295    let routed = selectors.into_iter().map(TopSSelector::finish).collect();
296    Ok((routed, ScoreBlockPath::Device))
297}
298
299mod device {
300    use super::score_block_kernel_source;
301    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
302    use ndarray::ArrayView2;
303    use std::collections::HashMap;
304    use std::sync::{Arc, Mutex, OnceLock};
305
306    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
307
308    struct Backend {
309        ctx: Arc<CudaContext>,
310        stream: Arc<CudaStream>,
311        modules: Mutex<HashMap<usize, Arc<CudaModule>>>,
312    }
313
314    fn backend() -> Result<&'static Backend, GpuError> {
315        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
316        BACKEND
317            .get_or_init(|| {
318                let parts = gam_gpu::backend_probe::probe_cuda_backend("sparse_dict_score_block")?;
319                Ok(Backend {
320                    ctx: parts.ctx,
321                    stream: parts.stream,
322                    modules: Mutex::new(HashMap::new()),
323                })
324            })
325            .as_ref()
326            .map_err(GpuError::clone)
327    }
328
329    fn module_for(b: &Backend, p: usize) -> Result<Arc<CudaModule>, GpuError> {
330        if let Ok(guard) = b.modules.lock() {
331            if let Some(m) = guard.get(&p) {
332                return Ok(m.clone());
333            }
334        }
335        let ptx = cudarc::nvrtc::compile_ptx(score_block_kernel_source(p))
336            .gpu_ctx_with(|err| format!("sparse_dict score-block NVRTC (P={p}): {err}"))?;
337        let module = b
338            .ctx
339            .load_module(ptx)
340            .gpu_ctx("sparse_dict score-block module load")?;
341        if let Ok(mut guard) = b.modules.lock() {
342            guard.entry(p).or_insert_with(|| module.clone());
343        }
344        Ok(module)
345    }
346
347    /// Compute the `n_rows × n_atoms` score block on the device. Flattens the
348    /// two views row-major (the kernel reads them as `[*, PP]`), launches one
349    /// thread per output element, and downloads the block.
350    pub(super) fn score_block_device(
351        rows: ArrayView2<'_, f32>,
352        atoms: ArrayView2<'_, f32>,
353    ) -> Result<Vec<f32>, GpuError> {
354        let n_rows = rows.nrows();
355        let n_atoms = atoms.nrows();
356        let p = rows.ncols();
357        if p != atoms.ncols() {
358            return Err(gam_gpu::gpu_err!(
359                "sparse_dict score-block: P mismatch rows={p} atoms={}",
360                atoms.ncols()
361            ));
362        }
363        if n_rows == 0 || n_atoms == 0 || p == 0 {
364            return Ok(vec![0.0f32; n_rows * n_atoms]);
365        }
366
367        let b = backend()?;
368        let module = module_for(b, p)?;
369        let func = module
370            .load_function("sparse_dict_score_block")
371            .gpu_ctx("sparse_dict score-block load_function")?;
372        let stream = b.stream.clone();
373
374        // Row-major contiguous host buffers (handles non-contiguous views).
375        let rows_host: Vec<f32> = rows.iter().copied().collect();
376        let atoms_host: Vec<f32> = atoms.iter().copied().collect();
377        assert_eq!(rows_host.len(), n_rows * p, "score-block rows flatten length");
378        assert_eq!(
379            atoms_host.len(),
380            n_atoms * p,
381            "score-block atoms flatten length"
382        );
383
384        let rows_dev = stream
385            .clone_htod(&rows_host)
386            .gpu_ctx("sparse_dict score-block htod rows")?;
387        let atoms_dev = stream
388            .clone_htod(&atoms_host)
389            .gpu_ctx("sparse_dict score-block htod atoms")?;
390        let mut scores_dev = stream
391            .alloc_zeros::<f32>(n_rows * n_atoms)
392            .gpu_ctx("sparse_dict score-block alloc scores")?;
393
394        let n_rows_i32 = i32::try_from(n_rows)
395            .map_err(|_| gam_gpu::gpu_err!("sparse_dict score-block n_rows={n_rows} overflows i32"))?;
396        let n_atoms_i32 = i32::try_from(n_atoms).map_err(|_| {
397            gam_gpu::gpu_err!("sparse_dict score-block n_atoms={n_atoms} overflows i32")
398        })?;
399
400        let total = n_rows * n_atoms;
401        let block: u32 = 256;
402        let grid: u32 = u32::try_from(total.div_ceil(block as usize))
403            .map_err(|_| gam_gpu::gpu_err!("sparse_dict score-block grid overflow"))?;
404        let cfg = LaunchConfig {
405            grid_dim: (grid, 1, 1),
406            block_dim: (block, 1, 1),
407            shared_mem_bytes: 0,
408        };
409        let mut builder = stream.launch_builder(&func);
410        builder
411            .arg(&rows_dev)
412            .arg(&atoms_dev)
413            .arg(&n_rows_i32)
414            .arg(&n_atoms_i32)
415            .arg(&mut scores_dev);
416        // SAFETY: grid/block validated; all device pointers are cudarc-checked
417        // allocations on this stream; the kernel reads rows[0..n_rows*P] /
418        // atoms[0..n_atoms*P] and writes within scores[0..n_rows*n_atoms].
419        unsafe { builder.launch(cfg) }.gpu_ctx("sparse_dict score-block launch")?;
420
421        let mut scores = vec![0.0f32; n_rows * n_atoms];
422        stream
423            .memcpy_dtoh(&scores_dev, &mut scores)
424            .gpu_ctx("sparse_dict score-block dtoh scores")?;
425        stream
426            .synchronize()
427            .gpu_ctx("sparse_dict score-block synchronize")?;
428        Ok(scores)
429    }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Deterministic fp32 fixture: `n_rows × p` rows and `n_atoms × p` unit-norm
    /// atoms (the lane unit-norms its decoder, so |xᵀd| is the projection).
    fn fixture(n_rows: usize, n_atoms: usize, p: usize) -> (Array2<f32>, Array2<f32>) {
        let rows = Array2::from_shape_fn((n_rows, p), |(i, c)| {
            (((i * 31 + c * 17) as f32) * 0.013).sin() * 0.9
        });
        let mut atoms = Array2::from_shape_fn((n_atoms, p), |(a, c)| {
            (((a * 7 + c * 5) as f32) * 0.011).cos()
        });
        for mut row in atoms.outer_iter_mut() {
            let norm = row.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-12);
            row.mapv_inplace(|v| v / norm);
        }
        (rows, atoms)
    }

    /// Route `m` fixture rows against `k` fixture atoms under
    /// `GpuMode::Required`. The result must equal the per-row CPU
    /// `top_s_online` oracle exactly: same selection lengths, same atoms, same
    /// score bits, same order.
    ///
    /// `m × k` must clear `DEVICE_SCORE_BLOCK_MIN_ELEMS`, so the device path is
    /// actually admitted. On a CUDA host, a `Required` error or a CPU-reported
    /// path is a hard failure. On a host without a runtime, `Required` must
    /// fail closed and `Auto` must reproduce the oracle on the CPU.
    #[cfg(target_os = "linux")]
    fn assert_route_matches_cpu_oracle(
        label: &str,
        m: usize,
        k: usize,
        p: usize,
        s: usize,
        tile: usize,
    ) {
        use crate::sparse_dict::scoring::top_s_online;

        assert!(
            m * k >= DEVICE_SCORE_BLOCK_MIN_ELEMS,
            "{label}: {m}×{k} block would not clear the device break-even"
        );
        let (rows, atoms) = fixture(m, k, p);

        let cpu: Vec<Vec<(u32, f32)>> = rows
            .outer_iter()
            .map(|row| top_s_online(row, atoms.view(), s, tile))
            .collect();

        match route_minibatch_required(
            rows.view(),
            atoms.view(),
            s,
            tile,
            gam_gpu::GpuMode::Required,
        ) {
            Ok((routed, path)) => {
                assert_eq!(
                    path,
                    ScoreBlockPath::Device,
                    "{label}: Required succeeded but reported CPU — device did not engage"
                );
                assert_eq!(routed.len(), cpu.len(), "{label}: routed row count differs");
                for (r, (dev_sel, cpu_sel)) in routed.iter().zip(&cpu).enumerate() {
                    assert_eq!(
                        dev_sel.len(),
                        cpu_sel.len(),
                        "{label} row {r}: selection length differs"
                    );
                    for (j, ((da, ds), (ca, cs))) in dev_sel.iter().zip(cpu_sel).enumerate() {
                        assert_eq!(
                            da, ca,
                            "{label} row {r} slot {j}: atom differs dev={da} cpu={ca}"
                        );
                        assert_eq!(
                            ds.to_bits(),
                            cs.to_bits(),
                            "{label} row {r} slot {j}: score bits differ dev={ds} cpu={cs}"
                        );
                    }
                }
            }
            Err(err) => {
                assert!(
                    gam_gpu::GpuRuntime::global().is_none(),
                    "{label}: Required errored despite a live CUDA runtime: {err}"
                );
                // Device absent: Auto must reproduce the CPU oracle exactly.
                let (routed, path) = route_minibatch_required(
                    rows.view(),
                    atoms.view(),
                    s,
                    tile,
                    gam_gpu::GpuMode::Auto,
                )
                .expect("Auto must not error on a device-absent host");
                assert_eq!(path, ScoreBlockPath::Cpu);
                assert_eq!(routed, cpu);
            }
        }
    }

    #[test]
    fn cpu_score_block_matches_score_row_tile() {
        // The block oracle must agree with the per-atom CPU router primitive.
        use crate::sparse_dict::scoring::{TopSSelector, score_row_tile};

        let (rows, atoms) = fixture(5, 9, 7);
        let n_atoms = atoms.nrows();
        let block = score_block_cpu(rows.view(), atoms.view());
        assert_eq!(block.len(), rows.nrows() * n_atoms);

        // 1. Every entry equals the sequential separate-rounding dot product.
        for r in 0..rows.nrows() {
            for a in 0..n_atoms {
                let mut acc = 0.0f32;
                for c in 0..rows.ncols() {
                    acc += rows[[r, c]] * atoms[[a, c]];
                }
                assert_eq!(
                    block[r * n_atoms + a].to_bits(),
                    acc.to_bits(),
                    "block oracle vs raw acc differ at r={r} a={a}"
                );
            }
        }

        // 2. `score_row_tile` (the primitive the device accelerates) selects
        //    atoms whose scores are bit-identical to that row's block entries.
        for r in 0..rows.nrows() {
            let mut sel = TopSSelector::new(3);
            score_row_tile(rows.row(r), atoms.view(), 0, &mut sel);
            let picked = sel.finish();
            assert!(
                picked.len() <= 3 && !picked.is_empty(),
                "row {r}: unexpected selection size {}",
                picked.len()
            );
            for &(atom, score) in &picked {
                let a = atom as usize;
                assert!(a < n_atoms, "row {r}: selected atom {a} out of range");
                assert_eq!(
                    score.to_bits(),
                    block[r * n_atoms + a].to_bits(),
                    "row {r} atom {a}: score_row_tile score differs from block oracle"
                );
            }
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_route_minibatch_matches_cpu_top_s_online() {
        // This is the router primitive the fit loop actually calls, at a modest
        // K. 512 × 4096 = 2,097,152 ≥ DEVICE_SCORE_BLOCK_MIN_ELEMS, and the
        // whole K fits in one GPU_ROUTE_TILE_ELEMS-sized launch.
        assert_route_matches_cpu_oracle("K=4096", 512, 4096, 48, 4, 1024);
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_route_at_issue_target_k_32k_is_bit_identical() {
        // #1026 HEADLINE SCALE: K up to ~32_000 ("large linear SAEs").
        // 256 × 32_768 = 8.4M elements clears the break-even by 8×, while host
        // buffers stay ~34 MB. With GPU_ROUTE_TILE_ELEMS / 256 = 8192 columns
        // per launch, this also exercises the cross-tile fold (four K-tiles).
        assert_route_matches_cpu_oracle("K=32k", 256, 32_768, 64, 4, 2048);
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_score_block_is_bit_identical_to_cpu_when_available() {
        // Exactness gate. The block MUST clear DEVICE_SCORE_BLOCK_MIN_ELEMS so
        // the device path is actually admitted; a sub-break-even block would
        // pass on the CPU and prove nothing.
        //
        // On a CUDA host we drive GpuMode::Required, so a silent CPU fallback
        // is a hard FAILURE, and we assert the device block is BIT-IDENTICAL
        // to the CPU reference. With no runtime, Required must fail closed and
        // the CPU path must stay exact.
        let n_rows = 256;
        let n_atoms = 4096; // 256*4096 = 1,048,576 == DEVICE_SCORE_BLOCK_MIN_ELEMS
        let p = 48;
        assert!(n_rows * n_atoms >= DEVICE_SCORE_BLOCK_MIN_ELEMS);
        let (rows, atoms) = fixture(n_rows, n_atoms, p);
        let cpu = score_block_cpu(rows.view(), atoms.view());

        match score_block_required(rows.view(), atoms.view(), gam_gpu::GpuMode::Required) {
            Ok((got, path)) => {
                assert_eq!(
                    path,
                    ScoreBlockPath::Device,
                    "Required succeeded but reported CPU — device did not engage"
                );
                assert_eq!(got.len(), cpu.len());
                for (i, (g, c)) in got.iter().zip(&cpu).enumerate() {
                    assert_eq!(
                        g.to_bits(),
                        c.to_bits(),
                        "device vs CPU score-block bit mismatch at {i}: dev={g} cpu={c}"
                    );
                }
            }
            Err(err) => {
                // No CUDA runtime on this host: Required correctly failed closed.
                assert!(
                    gam_gpu::GpuRuntime::global().is_none(),
                    "Required errored despite a live CUDA runtime: {err}"
                );
                // The CPU path must still be exact under Auto.
                let (got, path) =
                    score_block_required(rows.view(), atoms.view(), gam_gpu::GpuMode::Auto)
                        .expect("Auto must not error on a device-absent host");
                assert_eq!(path, ScoreBlockPath::Cpu);
                assert_eq!(got, cpu);
            }
        }
    }
}