#![cfg(target_os = "linux")]
use ndarray::ArrayView2;
/// CUDA C source of the `sparse_dict_score_block` kernel, minus the `PP` definition.
///
/// The kernel assigns one thread to each output element `idx` of the row-major
/// `[n_rows * n_atoms]` score matrix, decodes it into `r = idx / n_atoms` and
/// `a = idx % n_atoms`, and writes the dot product of row `r` with atom `a` over
/// `PP` columns. Threads with `idx >= n_rows * n_atoms` return without writing.
///
/// The product and the sum are issued as separate `__fmul_rn` / `__fadd_rn`
/// operations in ascending column order, so no fused multiply-add is used.
///
/// `PP` (the shared column count) is not defined in this text; it must be
/// `#define`d ahead of it, which is what [`score_block_kernel_source`] does.
pub const SCORE_BLOCK_KERNEL_SOURCE: &str = r#"
extern "C" __global__
void sparse_dict_score_block(
const float* __restrict__ rows, // [n_rows * PP] row-major
const float* __restrict__ atoms, // [n_atoms * PP] row-major (decoder tile)
int n_rows,
int n_atoms,
float* __restrict__ scores) // [n_rows * n_atoms] row-major
{
long long idx = (long long)blockIdx.x * blockDim.x + threadIdx.x;
long long total = (long long)n_rows * (long long)n_atoms;
if (idx >= total) return;
int r = (int)(idx / n_atoms);
int a = (int)(idx % n_atoms);
const float* xr = rows + (long long)r * PP;
const float* da = atoms + (long long)a * PP;
// SEPARATE-rounding accumulation in ascending c — NO fused multiply-add, so
// this is bit-identical to the CPU `acc += x[c]*d[c]` reference order.
float acc = 0.0f;
for (int c = 0; c < PP; ++c) {
float prod = __fmul_rn(xr[c], da[c]);
acc = __fadd_rn(acc, prod);
}
scores[idx] = acc;
}
"#;
#[must_use]
pub fn score_block_kernel_source(p: usize) -> String {
format!("#define PP {p}\n{SCORE_BLOCK_KERNEL_SOURCE}")
}
/// CPU reference for the score block: every row dotted with every atom.
///
/// Returns a row-major `n_rows * n_atoms` vector where element
/// `r * n_atoms + a` is the dot product of `rows.row(r)` and `atoms.row(a)`.
/// Each dot product starts from `0.0` and adds one `x * d` product at a time in
/// ascending column order — the same order the device kernel uses.
///
/// # Panics
///
/// Panics if `rows` and `atoms` do not have the same number of columns.
#[must_use]
pub fn score_block_cpu(rows: ArrayView2<'_, f32>, atoms: ArrayView2<'_, f32>) -> Vec<f32> {
    let p = rows.ncols();
    assert_eq!(p, atoms.ncols(), "score_block_cpu: P mismatch rows vs atoms");

    let mut scores = Vec::with_capacity(rows.nrows() * atoms.nrows());
    for row in rows.outer_iter() {
        for atom in atoms.outer_iter() {
            // Left fold keeps the strict ascending-column accumulation order.
            let dot = row
                .iter()
                .zip(atom.iter())
                .fold(0.0f32, |acc, (x, d)| acc + x * d);
            scores.push(dot);
        }
    }
    scores
}
/// Minimum number of score elements (`rows × atoms`) for which the device path
/// is attempted: `1 << 20` = 1,048,576.
///
/// Below this threshold [`score_block_required`] and
/// [`route_minibatch_required`] run on the CPU, or return an error when the
/// caller demanded the device.
pub const DEVICE_SCORE_BLOCK_MIN_ELEMS: usize = 1 << 20;
/// Reports which implementation produced a scoring or routing result.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScoreBlockPath {
/// The scores came from the CUDA kernel.
Device,
/// The scores came from the CPU reference implementation.
Cpu,
}
/// Scores every row against every atom, choosing device or CPU according to `mode`.
///
/// * `GpuMode::Off` — always the CPU reference.
/// * `GpuMode::Required` — the device must be used: a block smaller than
///   [`DEVICE_SCORE_BLOCK_MIN_ELEMS`] or any device failure is returned as an error.
/// * any other mode — the device is tried for blocks at or above the threshold,
///   and the CPU reference is used for small blocks or when the device fails.
///
/// Returns the row-major `n_rows * n_atoms` scores together with the path that
/// produced them.
///
/// # Errors
///
/// Only in `GpuMode::Required`, as described above.
///
/// # Panics
///
/// The CPU path panics if `rows` and `atoms` have different column counts.
pub fn score_block_required(
    rows: ArrayView2<'_, f32>,
    atoms: ArrayView2<'_, f32>,
    mode: gam_gpu::GpuMode,
) -> Result<(Vec<f32>, ScoreBlockPath), gam_gpu::GpuError> {
    use gam_gpu::GpuMode;

    let on_cpu = || (score_block_cpu(rows, atoms), ScoreBlockPath::Cpu);
    if mode == GpuMode::Off {
        return Ok(on_cpu());
    }

    let device_is_mandatory = mode == GpuMode::Required;
    let n_rows = rows.nrows();
    let n_atoms = atoms.nrows();
    let elems = n_rows.saturating_mul(n_atoms);

    if elems < DEVICE_SCORE_BLOCK_MIN_ELEMS {
        if device_is_mandatory {
            return Err(gam_gpu::gpu_err!(
                "sparse_dict score-block GpuMode::Required: block of {n_rows}×{n_atoms} \
                 = {elems} elems is below the device launch break-even \
                 (DEVICE_SCORE_BLOCK_MIN_ELEMS={DEVICE_SCORE_BLOCK_MIN_ELEMS}); refusing \
                 to silently run on the CPU"
            ));
        }
        return Ok(on_cpu());
    }

    match device::score_block_device(rows, atoms) {
        Ok(scores) => Ok((scores, ScoreBlockPath::Device)),
        Err(err) if device_is_mandatory => Err(err),
        // Best-effort modes deliberately drop the device error and fall back.
        Err(_) => Ok(on_cpu()),
    }
}
/// Target number of score elements per device launch when routing a minibatch
/// (`1 << 21` = 2,097,152).
///
/// [`route_minibatch_required`] divides this by the number of rows to pick how
/// many decoder atoms go into each tile, clamped to at least one atom and at
/// most the whole decoder.
const GPU_ROUTE_TILE_ELEMS: usize = 1 << 21;
pub fn route_minibatch_required(
rows: ArrayView2<'_, f32>,
decoder: ArrayView2<'_, f32>,
s: usize,
tile: usize,
mode: gam_gpu::GpuMode,
) -> Result<(Vec<Vec<(u32, f32)>>, ScoreBlockPath), gam_gpu::GpuError> {
use super::scoring::{TopSSelector, top_s_online};
let m = rows.nrows();
let k = decoder.nrows();
let cpu_route = || -> Vec<Vec<(u32, f32)>> {
rows.outer_iter()
.map(|row| top_s_online(row, decoder, s, tile))
.collect()
};
if mode == gam_gpu::GpuMode::Off {
return Ok((cpu_route(), ScoreBlockPath::Cpu));
}
let elems = m.saturating_mul(k);
let below_breakeven = elems < DEVICE_SCORE_BLOCK_MIN_ELEMS;
if below_breakeven {
if mode == gam_gpu::GpuMode::Required {
return Err(gam_gpu::gpu_err!(
"route_minibatch GpuMode::Required: block of {m}×{k} = {elems} elems is below \
the device launch break-even (DEVICE_SCORE_BLOCK_MIN_ELEMS={DEVICE_SCORE_BLOCK_MIN_ELEMS}); \
refusing to silently run on the CPU"
));
}
return Ok((cpu_route(), ScoreBlockPath::Cpu));
}
if m == 0 || k == 0 {
return Ok((cpu_route(), ScoreBlockPath::Cpu));
}
let tile_cols = (GPU_ROUTE_TILE_ELEMS / m).clamp(1, k);
let mut selectors: Vec<TopSSelector> = (0..m).map(|_| TopSSelector::new(s)).collect();
let mut start = 0usize;
while start < k {
let end = (start + tile_cols).min(k);
let atoms_tile = decoder.slice(ndarray::s![start..end, ..]);
match device::score_block_device(rows, atoms_tile) {
Ok(block) => {
let cols = end - start;
for (r, sel) in selectors.iter_mut().enumerate() {
let base = r * cols;
for (local, score) in block[base..base + cols].iter().enumerate() {
sel.offer((start + local) as u32, *score);
}
}
}
Err(err) => {
if mode == gam_gpu::GpuMode::Required {
return Err(err);
}
return Ok((cpu_route(), ScoreBlockPath::Cpu));
}
}
start = end;
}
let routed = selectors.into_iter().map(TopSSelector::finish).collect();
Ok((routed, ScoreBlockPath::Device))
}
/// CUDA device path for the score block: backend probing, a per-`P` cache of
/// compiled modules, and the host↔device round trip around one kernel launch.
mod device {
    use super::score_block_kernel_source;
    use gam_gpu::gpu_error::{GpuError, GpuResultExt};
    use ndarray::ArrayView2;
    use std::collections::HashMap;
    use std::sync::{Arc, Mutex, OnceLock};
    use cudarc::driver::{CudaContext, CudaModule, CudaStream, LaunchConfig, PushKernelArg};

    /// Process-wide CUDA state shared by every score-block launch.
    struct Backend {
        /// Context used to load compiled modules.
        ctx: Arc<CudaContext>,
        /// Stream on which all copies and launches are issued.
        stream: Arc<CudaStream>,
        /// Compiled modules keyed by the column count `P` they were built for.
        modules: Mutex<HashMap<usize, Arc<CudaModule>>>,
    }

    /// Returns the lazily-initialised backend, probing CUDA at most once.
    ///
    /// The probe outcome — success or failure — is stored in a `OnceLock`, so a
    /// failed probe is not retried; later calls get a clone of the same error.
    fn backend() -> Result<&'static Backend, GpuError> {
        static BACKEND: OnceLock<Result<Backend, GpuError>> = OnceLock::new();
        BACKEND
            .get_or_init(|| {
                let parts = gam_gpu::backend_probe::probe_cuda_backend("sparse_dict_score_block")?;
                Ok(Backend {
                    ctx: parts.ctx,
                    stream: parts.stream,
                    modules: Mutex::new(HashMap::new()),
                })
            })
            .as_ref()
            .map_err(GpuError::clone)
    }

    /// Returns the module compiled for column count `p`, compiling it on a cache miss.
    ///
    /// The cache is best-effort: if the mutex is poisoned, the lookup and the
    /// insert are skipped and a freshly compiled module is returned. The lock is
    /// not held during compilation, so concurrent misses may compile the same
    /// `p` more than once; the first inserted module stays in the cache.
    fn module_for(b: &Backend, p: usize) -> Result<Arc<CudaModule>, GpuError> {
        if let Ok(guard) = b.modules.lock() {
            if let Some(m) = guard.get(&p) {
                return Ok(Arc::clone(m));
            }
        }
        let ptx = cudarc::nvrtc::compile_ptx(score_block_kernel_source(p))
            .gpu_ctx_with(|err| format!("sparse_dict score-block NVRTC (P={p}): {err}"))?;
        let module = b
            .ctx
            .load_module(ptx)
            .gpu_ctx("sparse_dict score-block module load")?;
        if let Ok(mut guard) = b.modules.lock() {
            guard.entry(p).or_insert_with(|| module.clone());
        }
        Ok(module)
    }

    /// Computes the row-major `n_rows * n_atoms` score block on the device.
    ///
    /// Inputs are flattened in logical row-major order, uploaded, scored by the
    /// `sparse_dict_score_block` kernel with one thread per output element, and
    /// copied back; the stream is synchronised before returning.
    ///
    /// An empty block (`n_rows`, `n_atoms` or `p` equal to zero) returns zeros
    /// without touching the device.
    ///
    /// # Errors
    ///
    /// Returns an error when the column counts differ, when the output size or
    /// the kernel's `int` / grid dimensions would overflow, when no CUDA backend
    /// is available, or when any compile, copy, launch or synchronise step
    /// fails. All shape checks run before any allocation or transfer.
    pub(super) fn score_block_device(
        rows: ArrayView2<'_, f32>,
        atoms: ArrayView2<'_, f32>,
    ) -> Result<Vec<f32>, GpuError> {
        let n_rows = rows.nrows();
        let n_atoms = atoms.nrows();
        let p = rows.ncols();
        if p != atoms.ncols() {
            return Err(gam_gpu::gpu_err!(
                "sparse_dict score-block: P mismatch rows={p} atoms={}",
                atoms.ncols()
            ));
        }
        // Computed once and checked: a wrapped product would under-size the
        // output buffer while the kernel still indexes the true total.
        let total = n_rows.checked_mul(n_atoms).ok_or_else(|| {
            gam_gpu::gpu_err!("sparse_dict score-block output {n_rows}x{n_atoms} overflows usize")
        })?;
        if total == 0 || p == 0 {
            return Ok(vec![0.0f32; total]);
        }

        // Validate everything the launch needs before paying for uploads.
        let n_rows_i32 = i32::try_from(n_rows)
            .map_err(|_| gam_gpu::gpu_err!("sparse_dict score-block n_rows={n_rows} overflows i32"))?;
        let n_atoms_i32 = i32::try_from(n_atoms).map_err(|_| {
            gam_gpu::gpu_err!("sparse_dict score-block n_atoms={n_atoms} overflows i32")
        })?;
        let block: u32 = 256;
        let grid: u32 = u32::try_from(total.div_ceil(block as usize))
            .map_err(|_| gam_gpu::gpu_err!("sparse_dict score-block grid overflow"))?;
        let cfg = LaunchConfig {
            grid_dim: (grid, 1, 1),
            block_dim: (block, 1, 1),
            shared_mem_bytes: 0,
        };

        let b = backend()?;
        let module = module_for(b, p)?;
        let func = module
            .load_function("sparse_dict_score_block")
            .gpu_ctx("sparse_dict score-block load_function")?;
        let stream = b.stream.clone();

        // `iter()` walks the views in logical order, so the flattened buffers
        // are row-major regardless of the views' memory layout.
        let rows_host: Vec<f32> = rows.iter().copied().collect();
        let atoms_host: Vec<f32> = atoms.iter().copied().collect();
        assert_eq!(rows_host.len(), n_rows * p, "score-block rows flatten length");
        assert_eq!(
            atoms_host.len(),
            n_atoms * p,
            "score-block atoms flatten length"
        );

        let rows_dev = stream
            .clone_htod(&rows_host)
            .gpu_ctx("sparse_dict score-block htod rows")?;
        let atoms_dev = stream
            .clone_htod(&atoms_host)
            .gpu_ctx("sparse_dict score-block htod atoms")?;
        let mut scores_dev = stream
            .alloc_zeros::<f32>(total)
            .gpu_ctx("sparse_dict score-block alloc scores")?;

        let mut builder = stream.launch_builder(&func);
        builder
            .arg(&rows_dev)
            .arg(&atoms_dev)
            .arg(&n_rows_i32)
            .arg(&n_atoms_i32)
            .arg(&mut scores_dev);
        // SAFETY: the arguments match the kernel signature
        // `(const float*, const float*, int, int, float*)` in order. The module
        // was compiled with `PP == p`, `rows_dev` holds `n_rows * p` floats,
        // `atoms_dev` holds `n_atoms * p` floats and `scores_dev` holds
        // `n_rows * n_atoms` floats. The kernel returns early for any thread
        // index at or beyond `n_rows * n_atoms`, so every access is in bounds.
        unsafe { builder.launch(cfg) }.gpu_ctx("sparse_dict score-block launch")?;

        let mut scores = vec![0.0f32; total];
        stream
            .memcpy_dtoh(&scores_dev, &mut scores)
            .gpu_ctx("sparse_dict score-block dtoh scores")?;
        stream
            .synchronize()
            .gpu_ctx("sparse_dict score-block synchronize")?;
        Ok(scores)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Builds deterministic `(rows, atoms)` inputs; every atom row is scaled to
    /// unit L2 norm (with the norm floored at `1e-12`).
    fn fixture(n_rows: usize, n_atoms: usize, p: usize) -> (Array2<f32>, Array2<f32>) {
        let row_value = |(i, c): (usize, usize)| (((i * 31 + c * 17) as f32) * 0.013).sin() * 0.9;
        let atom_value = |(a, c): (usize, usize)| (((a * 7 + c * 5) as f32) * 0.011).cos();

        let rows = Array2::from_shape_fn((n_rows, p), row_value);
        let mut atoms = Array2::from_shape_fn((n_atoms, p), atom_value);
        for mut atom in atoms.outer_iter_mut() {
            let norm = atom.iter().map(|v| v * v).sum::<f32>().sqrt().max(1e-12);
            atom.iter_mut().for_each(|v| *v /= norm);
        }
        (rows, atoms)
    }

    /// CPU reference routing: `top_s_online` applied to every row.
    fn cpu_reference_routes(
        rows: &Array2<f32>,
        atoms: &Array2<f32>,
        s: usize,
        tile: usize,
    ) -> Vec<Vec<(u32, f32)>> {
        use crate::sparse_dict::scoring::top_s_online;
        rows.outer_iter()
            .map(|row| top_s_online(row, atoms.view(), s, tile))
            .collect()
    }

    /// Checks that `GpuMode::Auto` routing succeeds on the CPU path and equals
    /// the CPU reference.
    fn assert_auto_routes_on_cpu(
        rows: &Array2<f32>,
        atoms: &Array2<f32>,
        s: usize,
        tile: usize,
        cpu: &[Vec<(u32, f32)>],
    ) {
        let (routed, path) = route_minibatch_required(
            rows.view(),
            atoms.view(),
            s,
            tile,
            gam_gpu::GpuMode::Auto,
        )
        .expect("Auto must not error on a device-absent host");
        assert_eq!(path, ScoreBlockPath::Cpu);
        assert_eq!(routed.as_slice(), cpu);
    }

    /// The CPU block must match a hand-written accumulation bit for bit, and
    /// `score_row_tile` must yield between one and three picks for `s = 3`.
    #[test]
    fn cpu_score_block_matches_score_row_tile() {
        use crate::sparse_dict::scoring::{TopSSelector, score_row_tile};

        let (rows, atoms) = fixture(5, 9, 7);
        let block = score_block_cpu(rows.view(), atoms.view());
        let n_atoms = atoms.nrows();
        for r in 0..rows.nrows() {
            for a in 0..n_atoms {
                let expected =
                    (0..rows.ncols()).fold(0.0f32, |acc, c| acc + rows[[r, c]] * atoms[[a, c]]);
                assert_eq!(
                    block[r * n_atoms + a].to_bits(),
                    expected.to_bits(),
                    "block oracle vs raw acc differ at r={r} a={a}"
                );
            }
        }

        let mut sel = TopSSelector::new(3);
        score_row_tile(rows.row(0), atoms.view(), 0, &mut sel);
        let picked = sel.finish();
        assert!(picked.len() <= 3 && !picked.is_empty());
    }

    /// Device routing must pick the same atoms with bit-identical scores as the
    /// CPU reference; without a CUDA runtime, `Auto` must fall back to the CPU.
    #[cfg(target_os = "linux")]
    #[test]
    fn device_route_minibatch_matches_cpu_top_s_online() {
        let (m, k, p) = (512usize, 4096usize, 48usize);
        let (s, tile) = (4usize, 1024usize);
        assert!(m * k >= DEVICE_SCORE_BLOCK_MIN_ELEMS);

        let (rows, atoms) = fixture(m, k, p);
        let cpu = cpu_reference_routes(&rows, &atoms, s, tile);

        let outcome = route_minibatch_required(
            rows.view(),
            atoms.view(),
            s,
            tile,
            gam_gpu::GpuMode::Required,
        );
        let (routed, path) = match outcome {
            Ok(pair) => pair,
            Err(err) => {
                assert!(
                    gam_gpu::GpuRuntime::global().is_none(),
                    "Required errored despite a live CUDA runtime: {err}"
                );
                assert_auto_routes_on_cpu(&rows, &atoms, s, tile, &cpu);
                return;
            }
        };

        assert_eq!(
            path,
            ScoreBlockPath::Device,
            "Required succeeded but reported CPU — device did not engage"
        );
        assert_eq!(routed.len(), cpu.len());
        for (r, (dev_sel, cpu_sel)) in routed.iter().zip(&cpu).enumerate() {
            assert_eq!(
                dev_sel.len(),
                cpu_sel.len(),
                "row {r}: selection length differs"
            );
            for (j, ((da, ds), (ca, cs))) in dev_sel.iter().zip(cpu_sel).enumerate() {
                assert_eq!(da, ca, "row {r} slot {j}: atom differs dev={da} cpu={ca}");
                assert_eq!(
                    ds.to_bits(),
                    cs.to_bits(),
                    "row {r} slot {j}: score bits differ dev={ds} cpu={cs}"
                );
            }
        }
    }

    /// Same check as above with a 32,768-atom decoder.
    #[cfg(target_os = "linux")]
    #[test]
    fn device_route_at_issue_target_k_32k_is_bit_identical() {
        let (m, k, p) = (256usize, 32_768usize, 64usize);
        let (s, tile) = (4usize, 2048usize);
        assert!(m * k >= DEVICE_SCORE_BLOCK_MIN_ELEMS);

        let (rows, atoms) = fixture(m, k, p);
        let cpu = cpu_reference_routes(&rows, &atoms, s, tile);

        let outcome = route_minibatch_required(
            rows.view(),
            atoms.view(),
            s,
            tile,
            gam_gpu::GpuMode::Required,
        );
        let (routed, path) = match outcome {
            Ok(pair) => pair,
            Err(err) => {
                assert!(
                    gam_gpu::GpuRuntime::global().is_none(),
                    "Required errored at K=32k despite a live CUDA runtime: {err}"
                );
                assert_auto_routes_on_cpu(&rows, &atoms, s, tile, &cpu);
                return;
            }
        };

        assert_eq!(
            path,
            ScoreBlockPath::Device,
            "Required succeeded at K=32k but reported CPU — device did not engage"
        );
        assert_eq!(routed.len(), cpu.len());
        for (r, (dev_sel, cpu_sel)) in routed.iter().zip(&cpu).enumerate() {
            assert_eq!(dev_sel.len(), cpu_sel.len(), "row {r}: selection length differs");
            for (j, ((da, ds), (ca, cs))) in dev_sel.iter().zip(cpu_sel).enumerate() {
                assert_eq!(da, ca, "K=32k row {r} slot {j}: atom differs dev={da} cpu={ca}");
                assert_eq!(
                    ds.to_bits(),
                    cs.to_bits(),
                    "K=32k row {r} slot {j}: score bits differ dev={ds} cpu={cs}"
                );
            }
        }
    }

    /// The raw device score block must equal the CPU block bit for bit; without
    /// a CUDA runtime, `Auto` must fall back to the CPU.
    #[cfg(target_os = "linux")]
    #[test]
    fn device_score_block_is_bit_identical_to_cpu_when_available() {
        let n_rows = 256;
        let n_atoms = 4096;
        let p = 48;
        assert!(n_rows * n_atoms >= DEVICE_SCORE_BLOCK_MIN_ELEMS);

        let (rows, atoms) = fixture(n_rows, n_atoms, p);
        let cpu = score_block_cpu(rows.view(), atoms.view());

        let outcome = score_block_required(rows.view(), atoms.view(), gam_gpu::GpuMode::Required);
        let (got, path) = match outcome {
            Ok(pair) => pair,
            Err(err) => {
                assert!(
                    gam_gpu::GpuRuntime::global().is_none(),
                    "Required errored despite a live CUDA runtime: {err}"
                );
                let (got, path) =
                    score_block_required(rows.view(), atoms.view(), gam_gpu::GpuMode::Auto)
                        .expect("Auto must not error on a device-absent host");
                assert_eq!(path, ScoreBlockPath::Cpu);
                assert_eq!(got, cpu);
                return;
            }
        };

        assert_eq!(
            path,
            ScoreBlockPath::Device,
            "Required succeeded but reported CPU — device did not engage"
        );
        assert_eq!(got.len(), cpu.len());
        for (i, (g, c)) in got.iter().zip(&cpu).enumerate() {
            assert_eq!(
                g.to_bits(),
                c.to_bits(),
                "device vs CPU score-block bit mismatch at {i}: dev={g} cpu={c}"
            );
        }
    }
}