gam_sae/sparse_dict/scoring_gpu.rs
1//! GPU score-block kernel for the collapsed-linear-lane router (#1026).
2//!
3//! The collapsed linear lane ([`crate::sparse_dict`]) scales a linear SAE
4//! dictionary to `K ≈ 32_000` atoms by routing each row against the WHOLE
5//! dictionary one atom-tile at a time and keeping only the top-`s` atoms online.
6//! That route step — `scores[r][a] = Σ_c x[r][c]·decoder[a][c]` over a
7//! `rows × tile × P` block — is the dominant cost of a fit (a single fit at
8//! `K ≈ 32k` is the measured 1e4–1e6× hardware gap the issue tracks) and is the
9//! embarrassingly-parallel shape a GPU exists for.
10//!
11//! # What this offloads (and what it does NOT)
12//!
13//! This computes ONE atom-tile's `rows × tile` score block on the device,
14//! exactly the block the CPU [`super::scoring::score_row_tile`] folds into the
15//! per-row online top-`s` selectors. The selection logic itself
16//! ([`super::scoring::TopSSelector`]) stays single-sourced on the CPU: the
17//! device returns the score block, the host folds it into the SAME selectors,
18//! and discards it. The minibatch router [`route_minibatch_required`] walks the
19//! whole `K`-wide dictionary in atom-column tiles (each launch's block capped at
20//! `GPU_ROUTE_TILE_ELEMS`), so peak host/device score memory is `rows × tile`,
21//! **independent of `K`** — the lane's no-`N×K` memory discipline is preserved
22//! on the device exactly as on the CPU; the GPU just does the `O(rows·tile·P)`
23//! multiply-accumulate that dominates.
24//!
25//! # Bit-exact parity (the gate, not a tolerance)
26//!
//! The CPU oracle accumulates `acc += x[c]·d[c]` as a SEPARATE f32 multiply
//! followed by an f32 add, in ascending `c` order: rustc never contracts
//! `a*b+c` into a fused multiply-add on its own (there is no implicit
//! `-ffp-contract`-style fusion; an FMA appears only when `f32::mul_add` is
//! called explicitly). NVRTC, by contrast, defaults to `--fmad=true`, which
//! contracts `a*b+c` into a single-rounding FMA — a ~1 ULP difference that can
//! flip a near-tie top-`s` selection and make the routed support, and hence the
//! whole fit, diverge from the CPU oracle.
33//!
34//! So the kernel forces SEPARATE rounding with `__fmul_rn` + `__fadd_rn`, in the
35//! SAME ascending-`c` order, giving a score block that is **bit-for-bit**
36//! identical to the CPU `score_row_tile` (every `f32` equal under `to_bits`).
37//! Because the scores are identical, the CPU selector fed device scores produces
38//! the IDENTICAL routed support — parity is exact by construction, not bounded
39//! by a tolerance.
40
41#![cfg(target_os = "linux")]
42
43use ndarray::ArrayView2;
44
/// NVRTC source text of the `sparse_dict_score_block` kernel.
///
/// Work decomposition: one thread per `(row, atom)` output element, addressed
/// by the flat index `blockIdx.x * blockDim.x + threadIdx.x`; threads whose
/// index is at or beyond `n_rows * n_atoms` return without writing. Each thread
/// walks the `PP` columns in ascending order and accumulates with `__fmul_rn`
/// followed by `__fadd_rn` — one rounding for the multiply, one for the add —
/// so the stored score matches the CPU's sequential `acc += x·d` to the bit.
///
/// `PP` (the column count) is deliberately NOT defined in this text. It has to
/// be supplied as a `#define` in front of it (see
/// [`score_block_kernel_source`]), which gives the inner loop a fixed trip
/// count — the same convention as the other NVRTC kernels in this repo, which
/// monomorphise their shape macros for a pure `compile_ptx`.
pub const SCORE_BLOCK_KERNEL_SOURCE: &str = r#"
extern "C" __global__
void sparse_dict_score_block(
    const float* __restrict__ rows, // [n_rows * PP] row-major
    const float* __restrict__ atoms, // [n_atoms * PP] row-major (decoder tile)
    int n_rows,
    int n_atoms,
    float* __restrict__ scores) // [n_rows * n_atoms] row-major
{
    long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
    long long total = (long long)n_rows * (long long)n_atoms;
    if (idx >= total) return;
    int r = (int)(idx / n_atoms);
    int a = (int)(idx % n_atoms);
    const float* xr = rows + (long long)r * PP;
    const float* da = atoms + (long long)a * PP;
    // SEPARATE-rounding accumulation in ascending c — NO fused multiply-add, so
    // this is bit-identical to the CPU `acc += x[c]*d[c]` reference order.
    float acc = 0.0f;
    for (int c = 0; c < PP; ++c) {
        float prod = __fmul_rn(xr[c], da[c]);
        acc = __fadd_rn(acc, prod);
    }
    scores[idx] = acc;
}
"#;
78
79/// Prepend the `PP` shape macro so the NVRTC compile is a pure `compile_ptx`
80/// (mirrors `sae_rowjet::softmax_kernel_source` / `arrow_schur_nvrtc`).
81#[must_use]
82pub fn score_block_kernel_source(p: usize) -> String {
83 format!("#define PP {p}\n{SCORE_BLOCK_KERNEL_SOURCE}")
84}
85
86/// CPU reference for the score block: `scores[r*n_atoms + a] = Σ_c
87/// rows[r][c]·atoms[a][c]`, accumulated in ascending `c` with separate f32
88/// rounding — the SAME arithmetic [`super::scoring::score_row_tile`] runs
89/// per atom. This is the parity oracle the device kernel is locked against.
90#[must_use]
91pub fn score_block_cpu(rows: ArrayView2<'_, f32>, atoms: ArrayView2<'_, f32>) -> Vec<f32> {
92 let n_rows = rows.nrows();
93 let n_atoms = atoms.nrows();
94 let p = rows.ncols();
95 assert_eq!(p, atoms.ncols(), "score_block_cpu: P mismatch rows vs atoms");
96 let mut scores = vec![0.0f32; n_rows * n_atoms];
97 for r in 0..n_rows {
98 let xr = rows.row(r);
99 for a in 0..n_atoms {
100 let da = atoms.row(a);
101 let mut acc = 0.0f32;
102 for c in 0..p {
103 // separate mul then add — matches the kernel's __fmul_rn/__fadd_rn
104 acc += xr[c] * da[c];
105 }
106 scores[r * n_atoms + a] = acc;
107 }
108 }
109 scores
110}
111
/// Device-admission floor, measured in score-block OUTPUT elements
/// (`n_rows · n_atoms`).
///
/// A block with fewer elements than this is not worth the fixed cost of a
/// device launch (probe + H2D + D2H): the entry points in this file send it to
/// the CPU reference, or reject it under `GpuMode::Required`. Tuned to the same
/// genus as the other SAE device floors (`sae_rowjet::DEVICE_ROW_THRESHOLD`).
pub const DEVICE_SCORE_BLOCK_MIN_ELEMS: usize = 1024 * 1024; // 2^20 scores; each one costs P multiply-adds
117
/// The execution path that actually produced a score block.
///
/// The fail-loud entry points return this alongside their result so callers
/// (and the parity tests) can ASSERT that the device engaged, instead of
/// discovering a silent CPU fallback later — the #1026/#1551 'GPU 0%' failure
/// mode.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ScoreBlockPath {
    /// The block came from the NVRTC `sparse_dict_score_block` kernel.
    Device,
    /// The block came from the CPU `score_block_cpu` reference.
    Cpu,
}
128
129/// Fail-loud, residency-aware score-block entry point (#1026 scale-K lane).
130///
131/// Honours the process-wide [`gam_gpu::GpuMode`] contract: under
132/// [`gam_gpu::GpuMode::Required`] a missing CUDA runtime, an NVRTC/arch compile
133/// failure, a launch fault, or a block below the device break-even all return
134/// `Err` instead of silently degrading to the CPU. [`gam_gpu::GpuMode::Auto`]
135/// uses the device when admitted and the block clears the break-even, else the
136/// CPU; [`gam_gpu::GpuMode::Off`] always the CPU. The returned [`ScoreBlockPath`]
137/// reports which path actually ran.
138///
139/// Both paths produce a BIT-IDENTICAL `f32` score block (see module docs), so
140/// the routed top-`s` support is identical whichever path runs.
141///
142/// # Errors
143/// Returns [`gam_gpu::GpuError`] when [`gam_gpu::GpuMode::Required`] is set but
144/// the device path cannot run.
145pub fn score_block_required(
146 rows: ArrayView2<'_, f32>,
147 atoms: ArrayView2<'_, f32>,
148 mode: gam_gpu::GpuMode,
149) -> Result<(Vec<f32>, ScoreBlockPath), gam_gpu::GpuError> {
150 use gam_gpu::GpuMode;
151
152 let n_rows = rows.nrows();
153 let n_atoms = atoms.nrows();
154 let elems = n_rows.saturating_mul(n_atoms);
155
156 if mode == GpuMode::Off {
157 return Ok((score_block_cpu(rows, atoms), ScoreBlockPath::Cpu));
158 }
159
160 let below_breakeven = elems < DEVICE_SCORE_BLOCK_MIN_ELEMS;
161 if mode == GpuMode::Required && below_breakeven {
162 return Err(gam_gpu::gpu_err!(
163 "sparse_dict score-block GpuMode::Required: block of {n_rows}×{n_atoms} \
164 = {elems} elems is below the device launch break-even \
165 (DEVICE_SCORE_BLOCK_MIN_ELEMS={DEVICE_SCORE_BLOCK_MIN_ELEMS}); refusing \
166 to silently run on the CPU"
167 ));
168 }
169 if !below_breakeven {
170 match device::score_block_device(rows, atoms) {
171 Ok(out) => return Ok((out, ScoreBlockPath::Device)),
172 Err(err) => {
173 if mode == GpuMode::Required {
174 return Err(err);
175 }
176 // Auto: fall through to the CPU.
177 }
178 }
179 }
180
181 Ok((score_block_cpu(rows, atoms), ScoreBlockPath::Cpu))
182}
183
/// Cap on the number of score elements a single device launch of the tiled GPU
/// router may produce.
///
/// The router never materialises the whole `m × K` block. It walks `K` in
/// atom-column tiles sized so each launch's `m × cols` block stays under this
/// cap (~2M f32 ≈ 8 MB host + 8 MB device) and drops the block once it has been
/// folded. Peak score memory is therefore bounded **independent of `K`** — the
/// same discipline the CPU lane ([`super::scoring::top_s_online`]) keeps with
/// its `rows × tile` column tiles — so a `K ≈ 32_000` fit does not grow a
/// device allocation linearly in `K`.
const GPU_ROUTE_TILE_ELEMS: usize = 2 * 1024 * 1024;
192
193/// Route a whole minibatch of rows against the full decoder, returning each
194/// row's top-`s` `(atom, score)` selection — BIT-IDENTICAL to calling
195/// [`super::scoring::top_s_online`] per row, but with the score block computed on
196/// the device (in `K`-tiled launches) when admitted.
197///
198/// The selection ([`super::scoring::TopSSelector`]) is single-sourced on the CPU
199/// and fed the device scores in ascending atom order. `TopSSelector` keeps the
200/// top-`s` by `(|score| desc, atom asc)` — a strict total order on the unique
201/// atom indices — so the selected set is independent of the order or the tiling
202/// in which candidates are offered. Combined with bit-identical scores (the
203/// kernel forbids FMA contraction), the routed support matches the CPU oracle
204/// **exactly**, whichever path and whatever GPU tile width computed the block.
205///
206/// Memory: the `m × K` block is never formed whole — `K` is walked in tiles of
207/// at most `GPU_ROUTE_TILE_ELEMS / m` atom-columns, each launched, folded, and
208/// discarded, so peak score memory is `m × tile_cols`, independent of `K`.
209///
210/// Falls back to the per-row CPU `top_s_online` under [`gam_gpu::GpuMode::Off`],
211/// below the device break-even, or on any device error under
212/// [`gam_gpu::GpuMode::Auto`]; under [`gam_gpu::GpuMode::Required`] a device
213/// failure is propagated. The returned [`ScoreBlockPath`] reports which path ran.
214///
215/// # Errors
216/// Returns [`gam_gpu::GpuError`] when [`gam_gpu::GpuMode::Required`] is set but
217/// the device path cannot run for this minibatch.
218pub fn route_minibatch_required(
219 rows: ArrayView2<'_, f32>,
220 decoder: ArrayView2<'_, f32>,
221 s: usize,
222 tile: usize,
223 mode: gam_gpu::GpuMode,
224) -> Result<(Vec<Vec<(u32, f32)>>, ScoreBlockPath), gam_gpu::GpuError> {
225 use super::scoring::{TopSSelector, top_s_online};
226
227 let m = rows.nrows();
228 let k = decoder.nrows();
229
230 // CPU per-row path (bit-identical oracle), used for Off / below break-even /
231 // Auto device-error fallback.
232 let cpu_route = || -> Vec<Vec<(u32, f32)>> {
233 rows.outer_iter()
234 .map(|row| top_s_online(row, decoder, s, tile))
235 .collect()
236 };
237
238 if mode == gam_gpu::GpuMode::Off {
239 return Ok((cpu_route(), ScoreBlockPath::Cpu));
240 }
241
242 // Engagement is decided on the TOTAL work `m × K` (that is what justifies the
243 // device's fixed launch cost), but the launches themselves are K-tiled so the
244 // buffers never grow with K.
245 let elems = m.saturating_mul(k);
246 let below_breakeven = elems < DEVICE_SCORE_BLOCK_MIN_ELEMS;
247 if below_breakeven {
248 if mode == gam_gpu::GpuMode::Required {
249 return Err(gam_gpu::gpu_err!(
250 "route_minibatch GpuMode::Required: block of {m}×{k} = {elems} elems is below \
251 the device launch break-even (DEVICE_SCORE_BLOCK_MIN_ELEMS={DEVICE_SCORE_BLOCK_MIN_ELEMS}); \
252 refusing to silently run on the CPU"
253 ));
254 }
255 return Ok((cpu_route(), ScoreBlockPath::Cpu));
256 }
257 if m == 0 || k == 0 {
258 return Ok((cpu_route(), ScoreBlockPath::Cpu));
259 }
260
261 // Atom-columns per device launch: bound the per-launch block to
262 // GPU_ROUTE_TILE_ELEMS, at least one column, never more than K.
263 let tile_cols = (GPU_ROUTE_TILE_ELEMS / m).clamp(1, k);
264
265 // Per-row online selectors; each device tile's scores are folded in ascending
266 // global atom order (offset + ascending local), and the selector's result is
267 // tile-order-invariant, so the support is bit-identical to top_s_online.
268 let mut selectors: Vec<TopSSelector> = (0..m).map(|_| TopSSelector::new(s)).collect();
269 let mut start = 0usize;
270 while start < k {
271 let end = (start + tile_cols).min(k);
272 let atoms_tile = decoder.slice(ndarray::s![start..end, ..]);
273 match device::score_block_device(rows, atoms_tile) {
274 Ok(block) => {
275 let cols = end - start;
276 for (r, sel) in selectors.iter_mut().enumerate() {
277 let base = r * cols;
278 for (local, score) in block[base..base + cols].iter().enumerate() {
279 sel.offer((start + local) as u32, *score);
280 }
281 }
282 }
283 Err(err) => {
284 if mode == gam_gpu::GpuMode::Required {
285 return Err(err);
286 }
287 // Auto: the device faulted mid-route; discard partial selectors and
288 // run the exact CPU oracle for the whole minibatch.
289 return Ok((cpu_route(), ScoreBlockPath::Cpu));
290 }
291 }
292 start = end;
293 }
294
295 let routed = selectors.into_iter().map(TopSSelector::finish).collect();
296 Ok((routed, ScoreBlockPath::Device))
297}
298
299mod device {
300 use super::score_block_kernel_source;
301 use gam_gpu::gpu_error::{GpuError, GpuResultExt};
302 use ndarray::ArrayView2;
303 use std::collections::HashMap;
304 use std::sync::{Arc, Mutex, OnceLock};
305
306 use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};
307
308 struct Backend {
309 ctx: Arc<CudaContext>,
310 stream: Arc<CudaStream>,
311 modules: Mutex<HashMap<usize, Arc<CudaModule>>>,
312 }
313
314 fn backend() -> Result<&'static Backend, GpuError> {
315 static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
316 BACKEND
317 .get_or_init(|| {
318 let parts = gam_gpu::backend_probe::probe_cuda_backend("sparse_dict_score_block")?;
319 Ok(Backend {
320 ctx: parts.ctx,
321 stream: parts.stream,
322 modules: Mutex::new(HashMap::new()),
323 })
324 })
325 .as_ref()
326 .map_err(GpuError::clone)
327 }
328
329 fn module_for(b: &Backend, p: usize) -> Result<Arc<CudaModule>, GpuError> {
330 if let Ok(guard) = b.modules.lock() {
331 if let Some(m) = guard.get(&p) {
332 return Ok(m.clone());
333 }
334 }
335 let ptx = cudarc::nvrtc::compile_ptx(score_block_kernel_source(p))
336 .gpu_ctx_with(|err| format!("sparse_dict score-block NVRTC (P={p}): {err}"))?;
337 let module = b
338 .ctx
339 .load_module(ptx)
340 .gpu_ctx("sparse_dict score-block module load")?;
341 if let Ok(mut guard) = b.modules.lock() {
342 guard.entry(p).or_insert_with(|| module.clone());
343 }
344 Ok(module)
345 }
346
347 /// Compute the `n_rows × n_atoms` score block on the device. Flattens the
348 /// two views row-major (the kernel reads them as `[*, PP]`), launches one
349 /// thread per output element, and downloads the block.
350 pub(super) fn score_block_device(
351 rows: ArrayView2<'_, f32>,
352 atoms: ArrayView2<'_, f32>,
353 ) -> Result<Vec<f32>, GpuError> {
354 let n_rows = rows.nrows();
355 let n_atoms = atoms.nrows();
356 let p = rows.ncols();
357 if p != atoms.ncols() {
358 return Err(gam_gpu::gpu_err!(
359 "sparse_dict score-block: P mismatch rows={p} atoms={}",
360 atoms.ncols()
361 ));
362 }
363 if n_rows == 0 || n_atoms == 0 || p == 0 {
364 return Ok(vec![0.0f32; n_rows * n_atoms]);
365 }
366
367 let b = backend()?;
368 let module = module_for(b, p)?;
369 let func = module
370 .load_function("sparse_dict_score_block")
371 .gpu_ctx("sparse_dict score-block load_function")?;
372 let stream = b.stream.clone();
373
374 // Row-major contiguous host buffers (handles non-contiguous views).
375 let rows_host: Vec<f32> = rows.iter().copied().collect();
376 let atoms_host: Vec<f32> = atoms.iter().copied().collect();
377 assert_eq!(rows_host.len(), n_rows * p, "score-block rows flatten length");
378 assert_eq!(
379 atoms_host.len(),
380 n_atoms * p,
381 "score-block atoms flatten length"
382 );
383
384 let rows_dev = stream
385 .clone_htod(&rows_host)
386 .gpu_ctx("sparse_dict score-block htod rows")?;
387 let atoms_dev = stream
388 .clone_htod(&atoms_host)
389 .gpu_ctx("sparse_dict score-block htod atoms")?;
390 let mut scores_dev = stream
391 .alloc_zeros::<f32>(n_rows * n_atoms)
392 .gpu_ctx("sparse_dict score-block alloc scores")?;
393
394 let n_rows_i32 = i32::try_from(n_rows)
395 .map_err(|_| gam_gpu::gpu_err!("sparse_dict score-block n_rows={n_rows} overflows i32"))?;
396 let n_atoms_i32 = i32::try_from(n_atoms).map_err(|_| {
397 gam_gpu::gpu_err!("sparse_dict score-block n_atoms={n_atoms} overflows i32")
398 })?;
399
400 let total = n_rows * n_atoms;
401 let block: u32 = 256;
402 let grid: u32 = u32::try_from(total.div_ceil(block as usize))
403 .map_err(|_| gam_gpu::gpu_err!("sparse_dict score-block grid overflow"))?;
404 let cfg = LaunchConfig {
405 grid_dim: (grid, 1, 1),
406 block_dim: (block, 1, 1),
407 shared_mem_bytes: 0,
408 };
409 let mut builder = stream.launch_builder(&func);
410 builder
411 .arg(&rows_dev)
412 .arg(&atoms_dev)
413 .arg(&n_rows_i32)
414 .arg(&n_atoms_i32)
415 .arg(&mut scores_dev);
416 // SAFETY: grid/block validated; all device pointers are cudarc-checked
417 // allocations on this stream; the kernel reads rows[0..n_rows*P] /
418 // atoms[0..n_atoms*P] and writes within scores[0..n_rows*n_atoms].
419 unsafe { builder.launch(cfg) }.gpu_ctx("sparse_dict score-block launch")?;
420
421 let mut scores = vec![0.0f32; n_rows * n_atoms];
422 stream
423 .memcpy_dtoh(&scores_dev, &mut scores)
424 .gpu_ctx("sparse_dict score-block dtoh scores")?;
425 stream
426 .synchronize()
427 .gpu_ctx("sparse_dict score-block synchronize")?;
428 Ok(scores)
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Routed output: one `(atom, score)` selection per minibatch row.
    type Routed = Vec<Vec<(u32, f32)>>;

    /// Deterministic fp32 fixture: `n_rows × p` rows and `n_atoms × p` unit-norm
    /// atoms (the lane unit-norms its decoder, so |xᵀd| is the projection).
    fn fixture(n_rows: usize, n_atoms: usize, p: usize) -> (Array2<f32>, Array2<f32>) {
        let rows = Array2::from_shape_fn((n_rows, p), |(i, c)| {
            (((i * 31 + c * 17) as f32) * 0.013).sin() * 0.9
        });
        let mut atoms = Array2::from_shape_fn((n_atoms, p), |(a, c)| {
            (((a * 7 + c * 5) as f32) * 0.011).cos()
        });
        for mut atom in atoms.outer_iter_mut() {
            let norm = atom.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-12);
            atom.mapv_inplace(|v| v / norm);
        }
        (rows, atoms)
    }

    /// The per-row CPU oracle every router test is compared against.
    fn cpu_oracle(rows: &Array2<f32>, atoms: &Array2<f32>, s: usize, tile: usize) -> Routed {
        use crate::sparse_dict::scoring::top_s_online;
        rows.outer_iter()
            .map(|row| top_s_online(row, atoms.view(), s, tile))
            .collect()
    }

    /// Assert that a device-routed selection equals the CPU oracle exactly: same
    /// row count, same per-row length, same atoms, same score bits, same order.
    /// `label` is prepended to the per-slot failure messages.
    fn assert_routed_bits_eq(routed: &[Vec<(u32, f32)>], cpu: &[Vec<(u32, f32)>], label: &str) {
        assert_eq!(routed.len(), cpu.len());
        for (r, (dev_sel, cpu_sel)) in routed.iter().zip(cpu).enumerate() {
            assert_eq!(
                dev_sel.len(),
                cpu_sel.len(),
                "row {r}: selection length differs"
            );
            for (j, ((da, ds), (ca, cs))) in dev_sel.iter().zip(cpu_sel).enumerate() {
                assert_eq!(
                    da, ca,
                    "{label}row {r} slot {j}: atom differs dev={da} cpu={ca}"
                );
                assert_eq!(
                    ds.to_bits(),
                    cs.to_bits(),
                    "{label}row {r} slot {j}: score bits differ dev={ds} cpu={cs}"
                );
            }
        }
    }

    /// Device-absent leg of the router tests: Auto must take the CPU path and
    /// reproduce the oracle exactly.
    fn assert_auto_route_is_cpu_oracle(
        rows: &Array2<f32>,
        atoms: &Array2<f32>,
        s: usize,
        tile: usize,
        cpu: &[Vec<(u32, f32)>],
    ) {
        let (routed, path) = route_minibatch_required(
            rows.view(),
            atoms.view(),
            s,
            tile,
            gam_gpu::GpuMode::Auto,
        )
        .expect("Auto must not error on a device-absent host");
        assert_eq!(path, ScoreBlockPath::Cpu);
        assert_eq!(routed.as_slice(), cpu);
    }

    #[test]
    fn cpu_score_block_matches_score_row_tile() {
        // The block oracle must equal the per-atom CPU router primitive exactly.
        use crate::sparse_dict::scoring::score_row_tile;
        let (rows, atoms) = fixture(5, 9, 7);
        let block = score_block_cpu(rows.view(), atoms.view());
        let n_atoms = atoms.nrows();
        for (r, xr) in rows.outer_iter().enumerate() {
            // score_row_tile folds into a selector; reproduce its raw scores by
            // running the same acc loop it uses (separate mul/add, ascending c).
            for (a, da) in atoms.outer_iter().enumerate() {
                let mut acc = 0.0f32;
                for (x, d) in xr.iter().zip(da.iter()) {
                    acc += x * d;
                }
                assert_eq!(
                    block[r * n_atoms + a].to_bits(),
                    acc.to_bits(),
                    "block oracle vs raw acc differ at r={r} a={a}"
                );
            }
        }
        // And score_row_tile's selection over the full block is reproducible
        // from the same scores (sanity: the primitive is the one we accelerate).
        let mut sel = crate::sparse_dict::scoring::TopSSelector::new(3);
        score_row_tile(rows.row(0), atoms.view(), 0, &mut sel);
        let picked = sel.finish();
        assert!(picked.len() <= 3 && !picked.is_empty());
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_route_minibatch_matches_cpu_top_s_online() {
        // The router primitive the fit loop actually calls. The m×K block MUST
        // clear DEVICE_SCORE_BLOCK_MIN_ELEMS so the device path is admitted. On a
        // CUDA host we drive Required (silent CPU fallback = hard failure) and
        // assert the routed top-s support EQUALS the per-row CPU `top_s_online`
        // oracle exactly — same atoms, same bit-identical scores, same order.
        let m = 512usize;
        let k = 4096usize; // 512*4096 = 2,097,152 >= DEVICE_SCORE_BLOCK_MIN_ELEMS
        let p = 48usize;
        let s = 4usize;
        let tile = 1024usize;
        assert!(m * k >= DEVICE_SCORE_BLOCK_MIN_ELEMS);
        let (rows, atoms) = fixture(m, k, p);
        let cpu = cpu_oracle(&rows, &atoms, s, tile);

        let outcome = route_minibatch_required(
            rows.view(),
            atoms.view(),
            s,
            tile,
            gam_gpu::GpuMode::Required,
        );
        match outcome {
            Ok((routed, path)) => {
                assert_eq!(
                    path,
                    ScoreBlockPath::Device,
                    "Required succeeded but reported CPU — device did not engage"
                );
                assert_routed_bits_eq(&routed, &cpu, "");
            }
            Err(err) => {
                assert!(
                    gam_gpu::GpuRuntime::global().is_none(),
                    "Required errored despite a live CUDA runtime: {err}"
                );
                // Device absent: Auto must reproduce the CPU oracle exactly.
                assert_auto_route_is_cpu_oracle(&rows, &atoms, s, tile, &cpu);
            }
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_route_at_issue_target_k_32k_is_bit_identical() {
        // #1026 HEADLINE SCALE. The issue is about "large linear SAEs" — K up to
        // ~32_000. Our other parity test pins K=4096; this one drives the router
        // at the issue's actual target width (K=32_768) to prove the device path
        // not only engages but stays BIT-IDENTICAL to the per-row CPU oracle at
        // the scale where the 1e4–1e6× hardware gap the issue tracks lives. m is
        // kept modest (256) so the 256×32_768 = 8.4M-element block clears the
        // device break-even by 8× while the host buffers stay ~34 MB.
        let m = 256usize;
        let k = 32_768usize; // 256 * 32_768 = 8,388,608 >> DEVICE_SCORE_BLOCK_MIN_ELEMS
        let p = 64usize;
        let s = 4usize;
        let tile = 2048usize;
        assert!(m * k >= DEVICE_SCORE_BLOCK_MIN_ELEMS);
        let (rows, atoms) = fixture(m, k, p);
        let cpu = cpu_oracle(&rows, &atoms, s, tile);

        let outcome = route_minibatch_required(
            rows.view(),
            atoms.view(),
            s,
            tile,
            gam_gpu::GpuMode::Required,
        );
        match outcome {
            Ok((routed, path)) => {
                assert_eq!(
                    path,
                    ScoreBlockPath::Device,
                    "Required succeeded at K=32k but reported CPU — device did not engage"
                );
                assert_routed_bits_eq(&routed, &cpu, "K=32k ");
            }
            Err(err) => {
                assert!(
                    gam_gpu::GpuRuntime::global().is_none(),
                    "Required errored at K=32k despite a live CUDA runtime: {err}"
                );
                assert_auto_route_is_cpu_oracle(&rows, &atoms, s, tile, &cpu);
            }
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    fn device_score_block_is_bit_identical_to_cpu_when_available() {
        // Exactness gate. The block MUST clear DEVICE_SCORE_BLOCK_MIN_ELEMS so
        // the device path is actually admitted (a sub-break-even block would
        // skip-pass on the CPU and prove nothing). On a CUDA host we drive
        // GpuMode::Required so a silent CPU fallback is a hard FAILURE, and we
        // assert the device block is BIT-IDENTICAL to the CPU reference. With no
        // runtime, Required must fail closed and the CPU path stays exact.
        let n_rows = 256;
        let n_atoms = 4096; // 256*4096 = 1,048,576 == DEVICE_SCORE_BLOCK_MIN_ELEMS
        let p = 48;
        assert!(n_rows * n_atoms >= DEVICE_SCORE_BLOCK_MIN_ELEMS);
        let (rows, atoms) = fixture(n_rows, n_atoms, p);
        let cpu = score_block_cpu(rows.view(), atoms.view());

        let outcome = score_block_required(rows.view(), atoms.view(), gam_gpu::GpuMode::Required);
        match outcome {
            Ok((got, path)) => {
                assert_eq!(
                    path,
                    ScoreBlockPath::Device,
                    "Required succeeded but reported CPU — device did not engage"
                );
                assert_eq!(got.len(), cpu.len());
                for (i, (g, c)) in got.iter().zip(&cpu).enumerate() {
                    assert_eq!(
                        g.to_bits(),
                        c.to_bits(),
                        "device vs CPU score-block bit mismatch at {i}: dev={g} cpu={c}"
                    );
                }
            }
            Err(err) => {
                // No CUDA runtime on this host: Required correctly failed closed.
                assert!(
                    gam_gpu::GpuRuntime::global().is_none(),
                    "Required errored despite a live CUDA runtime: {err}"
                );
                // The CPU path must still be exact under Auto.
                let (got, path) =
                    score_block_required(rows.view(), atoms.view(), gam_gpu::GpuMode::Auto)
                        .expect("Auto must not error on a device-absent host");
                assert_eq!(path, ScoreBlockPath::Cpu);
                assert_eq!(got, cpu);
            }
        }
    }
}
667}