#![cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, LaunchConfig, PushKernelArg};
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::kernels::gpu_cumsum;
use crate::module_cache::get_or_compile;
use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64, gpu_to_cpu};
// Threads per block used by `launch_1d` and `launch_topk_config`.
const BLOCK_SIZE: u32 = 256;
// Values passed as the searchsorted kernels' `right` parameter: the kernels
// treat zero as the left (lower-bound) search and any non-zero value as the
// right (upper-bound) search.
const SIDE_LEFT: u32 = 0; const SIDE_RIGHT: u32 = 1;
/// Builds a one-dimensional launch configuration with one thread per element.
///
/// The grid holds `ceil(n / BLOCK_SIZE)` blocks of `BLOCK_SIZE` threads and is
/// never smaller than one block, so `n == 0` still yields a launchable config.
/// No dynamic shared memory is requested.
fn launch_1d(n: usize) -> LaunchConfig {
    // Round up in `usize` before narrowing. Adding `BLOCK_SIZE - 1` in `u32`
    // saturates once `n` is within 255 of `u32::MAX`, which turns the ceiling
    // into a floor and leaves the tail elements without a thread; a plain
    // `n as u32` would additionally truncate larger `n`.
    let blocks = n.div_ceil(BLOCK_SIZE as usize);
    // Clamp instead of wrapping if the block count itself exceeds `u32`.
    let grid = u32::try_from(blocks).unwrap_or(u32::MAX);
    LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// PTX source for `searchsorted_f32_kernel`.
///
/// Each thread handles one element of `vals` (threads whose index is
/// `>= n_vals` exit immediately) and binary-searches `bounds[0..n_bounds]` for
/// it, storing the resulting insertion index into `out` as an `s64`.
/// `right == 0` advances the lower end while `!(bound >= value)` (lower bound);
/// any other `right` advances while `!(bound > value)` (upper bound). The
/// comparisons are ordered, so a NaN on either side always advances the search.
const SEARCHSORTED_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry searchsorted_f32_kernel(
.param .u64 vals_ptr,
.param .u64 bounds_ptr,
.param .u64 out_ptr,
.param .u32 n_vals,
.param .u32 n_bounds,
.param .u32 right
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %t, %nv, %nb, %rt;
.reg .u32 %lo, %hi, %mid, %half, %mid1;
.reg .u64 %vals_p, %bnd_p, %out_p, %off, %addr;
.reg .f32 %v, %bv;
.reg .s64 %res;
.reg .pred %p_oob, %p_loop, %p_is_right, %p_not_right, %p_adv;
.reg .pred %p_ge, %p_gt, %p_nge, %p_ngt, %p_a, %p_b;
ld.param.u64 %vals_p, [vals_ptr];
ld.param.u64 %bnd_p, [bounds_ptr];
ld.param.u64 %out_p, [out_ptr];
ld.param.u32 %nv, [n_vals];
ld.param.u32 %nb, [n_bounds];
ld.param.u32 %rt, [right];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %t, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %t, %nv;
@%p_oob bra DONE;
// v = vals[t]
cvt.u64.u32 %off, %t;
shl.b64 %off, %off, 2;
add.u64 %addr, %vals_p, %off;
ld.global.f32 %v, [%addr];
setp.ne.u32 %p_is_right, %rt, 0; // p_is_right = (right != 0)
setp.eq.u32 %p_not_right, %rt, 0; // p_not_right = (right == 0)
mov.u32 %lo, 0;
mov.u32 %hi, %nb;
LOOP:
setp.ge.u32 %p_loop, %lo, %hi; // exit when lo >= hi
@%p_loop bra STORE;
// mid = lo + ((hi - lo) >> 1)
sub.u32 %half, %hi, %lo;
shr.u32 %half, %half, 1;
add.u32 %mid, %lo, %half;
// bv = bounds[mid]
cvt.u64.u32 %off, %mid;
shl.b64 %off, %off, 2;
add.u64 %addr, %bnd_p, %off;
ld.global.f32 %bv, [%addr];
// advance predicate (no `selp.pred`; build it with predicate logic),
// mirroring upstream aten/src/ATen/native/cuda/Bucketization.cu:33,51:
// left (lower_bound): advance while `!(bv >= v)` (Bucketization.cu:33)
// right (upper_bound): advance while `!(bv > v)` (Bucketization.cu:51)
// p_adv = (right & !(bv > v)) | (!right & !(bv >= v))
// `setp.ge`/`setp.gt` are ORDERED (false for NaN), so the negation is TRUE
// for a NaN value -> always advance -> lo = len, matching torch. For finite
// operands `!(bv >= v) == (bv < v)` and `!(bv > v) == (bv <= v)`, so the
// finite tie/dup/oob cases are byte-identical to the prior setp.lt/le form.
setp.ge.f32 %p_ge, %bv, %v; // p_ge = (bv >= v), ordered (false for NaN)
setp.gt.f32 %p_gt, %bv, %v; // p_gt = (bv > v), ordered (false for NaN)
not.pred %p_nge, %p_ge; // p_nge = !(bv >= v) (true for NaN)
not.pred %p_ngt, %p_gt; // p_ngt = !(bv > v) (true for NaN)
and.pred %p_a, %p_is_right, %p_ngt;
and.pred %p_b, %p_not_right, %p_nge;
or.pred %p_adv, %p_a, %p_b;
// if advance: lo = mid + 1 ; else: hi = mid
add.u32 %mid1, %mid, 1;
@%p_adv mov.u32 %lo, %mid1;
@!%p_adv mov.u32 %hi, %mid;
bra LOOP;
STORE:
cvt.s64.u32 %res, %lo;
cvt.u64.u32 %off, %t;
shl.b64 %off, %off, 3;
add.u64 %addr, %out_p, %off;
st.global.s64 [%addr], %res;
DONE:
ret;
}
";
/// PTX source for `searchsorted_f64_kernel`.
///
/// Same algorithm as the `f32` kernel but with 8-byte element loads: each
/// thread with index `< n_vals` binary-searches `bounds[0..n_bounds]` for its
/// value and stores the insertion index into `out` as an `s64`. `right == 0`
/// selects the lower-bound rule, any other value the upper-bound rule, and the
/// ordered comparisons make a NaN operand always advance the search.
const SEARCHSORTED_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry searchsorted_f64_kernel(
.param .u64 vals_ptr,
.param .u64 bounds_ptr,
.param .u64 out_ptr,
.param .u32 n_vals,
.param .u32 n_bounds,
.param .u32 right
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %t, %nv, %nb, %rt;
.reg .u32 %lo, %hi, %mid, %half, %mid1;
.reg .u64 %vals_p, %bnd_p, %out_p, %off, %addr;
.reg .f64 %v, %bv;
.reg .s64 %res;
.reg .pred %p_oob, %p_loop, %p_is_right, %p_not_right, %p_adv;
.reg .pred %p_ge, %p_gt, %p_nge, %p_ngt, %p_a, %p_b;
ld.param.u64 %vals_p, [vals_ptr];
ld.param.u64 %bnd_p, [bounds_ptr];
ld.param.u64 %out_p, [out_ptr];
ld.param.u32 %nv, [n_vals];
ld.param.u32 %nb, [n_bounds];
ld.param.u32 %rt, [right];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %t, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %t, %nv;
@%p_oob bra DONE;
cvt.u64.u32 %off, %t;
shl.b64 %off, %off, 3;
add.u64 %addr, %vals_p, %off;
ld.global.f64 %v, [%addr];
setp.ne.u32 %p_is_right, %rt, 0;
setp.eq.u32 %p_not_right, %rt, 0;
mov.u32 %lo, 0;
mov.u32 %hi, %nb;
LOOP:
setp.ge.u32 %p_loop, %lo, %hi;
@%p_loop bra STORE;
sub.u32 %half, %hi, %lo;
shr.u32 %half, %half, 1;
add.u32 %mid, %lo, %half;
cvt.u64.u32 %off, %mid;
shl.b64 %off, %off, 3;
add.u64 %addr, %bnd_p, %off;
ld.global.f64 %bv, [%addr];
// advance predicate mirroring aten/src/ATen/native/cuda/Bucketization.cu:33,51:
// left (lower_bound): advance while `!(bv >= v)` (Bucketization.cu:33)
// right (upper_bound): advance while `!(bv > v)` (Bucketization.cu:51)
// `setp.ge`/`setp.gt` are ORDERED (false for NaN) -> negation TRUE for NaN ->
// always advance -> lo = len, matching torch. Finite operands unchanged.
setp.ge.f64 %p_ge, %bv, %v; // p_ge = (bv >= v), ordered (false for NaN)
setp.gt.f64 %p_gt, %bv, %v; // p_gt = (bv > v), ordered (false for NaN)
not.pred %p_nge, %p_ge; // p_nge = !(bv >= v) (true for NaN)
not.pred %p_ngt, %p_gt; // p_ngt = !(bv > v) (true for NaN)
and.pred %p_a, %p_is_right, %p_ngt;
and.pred %p_b, %p_not_right, %p_nge;
or.pred %p_adv, %p_a, %p_b;
add.u32 %mid1, %mid, 1;
@%p_adv mov.u32 %lo, %mid1;
@!%p_adv mov.u32 %hi, %mid;
bra LOOP;
STORE:
cvt.s64.u32 %res, %lo;
cvt.u64.u32 %off, %t;
shl.b64 %off, %off, 3;
add.u64 %addr, %out_p, %off;
st.global.s64 [%addr], %res;
DONE:
ret;
}
";
/// Validates the arguments, then launches the searchsorted kernel
/// `kernel_name` (compiled from `ptx`) with one thread per value.
///
/// * `values` / `boundaries` - device buffers holding at least `n_vals` /
///   `n_bounds` elements.
/// * `right` - `false` passes `SIDE_LEFT`, `true` passes `SIDE_RIGHT` as the
///   kernel's `right` parameter.
///
/// Returns a device buffer of `n_vals` `i64` insertion indices; an empty buffer
/// is returned without compiling or launching anything when `n_vals == 0`.
///
/// # Errors
///
/// * `GpuError::LengthMismatch` when a buffer is shorter than its declared
///   element count, or when `n_vals` / `n_bounds` does not fit in `u32` (the
///   kernel's counters are 32-bit).
/// * `GpuError::PtxCompileFailed` when the kernel cannot be compiled/loaded.
/// * Any allocation or launch error raised by the driver.
// The parameters mirror the kernel ABI plus the PTX/entry-point pair, so a
// wrapper struct would only add indirection for two call sites.
#[allow(clippy::too_many_arguments)]
fn launch_searchsorted<V>(
    values: &CudaSlice<V>,
    boundaries: &CudaSlice<V>,
    n_vals: usize,
    n_bounds: usize,
    right: bool,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<CudaSlice<i64>>
where
    V: cudarc::driver::DeviceRepr,
{
    if values.len() < n_vals {
        return Err(GpuError::LengthMismatch {
            a: values.len(),
            b: n_vals,
        });
    }
    if boundaries.len() < n_bounds {
        return Err(GpuError::LengthMismatch {
            a: boundaries.len(),
            b: n_bounds,
        });
    }
    if n_vals > u32::MAX as usize || n_bounds > u32::MAX as usize {
        // Report whichever count is actually out of range; reporting `n_vals`
        // unconditionally is misleading when only `n_bounds` is too large.
        return Err(GpuError::LengthMismatch {
            a: n_vals.max(n_bounds),
            b: u32::MAX as usize,
        });
    }
    let stream = device.stream();
    if n_vals == 0 {
        return Ok(stream.alloc_zeros::<i64>(0)?);
    }
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out = stream.alloc_zeros::<i64>(n_vals)?;
    let cfg = launch_1d(n_vals);
    // Both casts are lossless: the counts were range-checked above.
    let n_vals_u = n_vals as u32;
    let n_bounds_u = n_bounds as u32;
    let right_u = if right { SIDE_RIGHT } else { SIDE_LEFT };
    // SAFETY: the argument order matches the searchsorted kernels' parameter
    // list (vals, bounds, out, n_vals, n_bounds, right). `values` and
    // `boundaries` were checked to hold at least `n_vals` / `n_bounds` elements
    // and `out` holds exactly `n_vals`; the kernels only touch `vals[t]` /
    // `out[t]` for `t < n_vals` and `bounds[mid]` for `mid < n_bounds`. This
    // assumes `ptx` / `kernel_name` name one of the searchsorted kernels
    // defined in this file with an element type matching `V`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(values)
            .arg(boundaries)
            .arg(&mut out)
            .arg(&n_vals_u)
            .arg(&n_bounds_u)
            .arg(&right_u)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Runs the `f32` searchsorted kernel on the device.
///
/// For each of the first `n_vals` entries of `values`, the kernel
/// binary-searches the first `n_bounds` entries of `boundaries` and yields the
/// insertion index as an `i64`. `right` selects the upper-bound variant;
/// otherwise the lower bound is used.
///
/// # Errors
///
/// Forwards every error produced by `launch_searchsorted`.
pub fn gpu_searchsorted_f32(
    values: &CudaSlice<f32>,
    boundaries: &CudaSlice<f32>,
    n_vals: usize,
    n_bounds: usize,
    right: bool,
    device: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    // Entry point exported by `SEARCHSORTED_F32_PTX`.
    const ENTRY: &str = "searchsorted_f32_kernel";
    launch_searchsorted(
        values,
        boundaries,
        n_vals,
        n_bounds,
        right,
        device,
        SEARCHSORTED_F32_PTX,
        ENTRY,
    )
}
/// Runs the `f64` searchsorted kernel on the device.
///
/// For each of the first `n_vals` entries of `values`, the kernel
/// binary-searches the first `n_bounds` entries of `boundaries` and yields the
/// insertion index as an `i64`. `right` selects the upper-bound variant;
/// otherwise the lower bound is used.
///
/// # Errors
///
/// Forwards every error produced by `launch_searchsorted`.
pub fn gpu_searchsorted_f64(
    values: &CudaSlice<f64>,
    boundaries: &CudaSlice<f64>,
    n_vals: usize,
    n_bounds: usize,
    right: bool,
    device: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    // Entry point exported by `SEARCHSORTED_F64_PTX`.
    const ENTRY: &str = "searchsorted_f64_kernel";
    launch_searchsorted(
        values,
        boundaries,
        n_vals,
        n_bounds,
        right,
        device,
        SEARCHSORTED_F64_PTX,
        ENTRY,
    )
}
/// Expands to the PTX source of a top-k selection kernel as a single
/// `concat!`-ed string literal.
///
/// * `$entry` - name of the kernel entry point.
/// * `$tyld`  - PTX floating-point type suffix spliced into the register
///   declaration, loads, stores, moves and comparisons.
/// * `$shift` - shift amount applied to element indices to form byte offsets
///   (log2 of the element size).
///
/// Each thread with index `s < outer` works on the `dim` elements starting at
/// element `s * dim` and performs `k` passes. Pass `j` scans the whole slice,
/// keeps only elements that rank strictly after the previous pick (value order
/// first, ascending index on equal rank) and picks the best of those; the value
/// goes to `vals[s * k + j]` and its index, widened to `s64`, to
/// `idx[s * k + j]`. NaN is ranked above every other value.
///
/// NOTE(review): `s * dim` and `s * k` are formed with `mul.lo.u32`, so they
/// wrap if the product exceeds `u32::MAX` - the host side has to rule that out.
macro_rules! topk_ptx {
($entry:literal, $tyld:literal, $shift:literal) => {
concat!(
".version 7.0\n.target sm_52\n.address_size 64\n\n.visible .entry ",
$entry,
"(\n",
" .param .u64 in_ptr,\n",
" .param .u64 vals_ptr,\n",
" .param .u64 idx_ptr,\n",
" .param .u32 outer,\n",
" .param .u32 dim,\n",
" .param .u32 k,\n",
" .param .u32 largest\n",
") {\n",
" .reg .u32 %tid_r, %bid_r, %bdim_r, %s, %no, %nd, %nk, %lg;\n",
" .reg .u32 %j, %i, %prev_idx, %best_idx, %cur_idx;\n",
" .reg .u64 %in_p, %vp, %ip, %slice_off, %off, %addr, %tmp64;\n",
" .reg .",
$tyld,
" %prev_val, %best_val, %cur_val;\n",
" .reg .s64 %ridx;\n",
" .reg .pred %p_oob, %p_jloop, %p_iloop, %p_lg, %p_have, %p_first;\n",
" .reg .pred %p_elig, %p_beat, %p_vgt, %p_vlt, %p_veq, %p_vsel, %p_idx;\n",
" .reg .pred %p_pgt, %p_plt, %p_peq, %p_psel, %p_pidx, %p_upd;\n",
" .reg .pred %p_cnan, %p_pnan, %p_bnan, %p_na, %p_nb;\n",
"\n",
" ld.param.u64 %in_p, [in_ptr];\n",
" ld.param.u64 %vp, [vals_ptr];\n",
" ld.param.u64 %ip, [idx_ptr];\n",
" ld.param.u32 %no, [outer];\n",
" ld.param.u32 %nd, [dim];\n",
" ld.param.u32 %nk, [k];\n",
" ld.param.u32 %lg, [largest];\n",
"\n",
" mov.u32 %tid_r, %tid.x;\n",
" mov.u32 %bid_r, %ctaid.x;\n",
" mov.u32 %bdim_r, %ntid.x;\n",
" mad.lo.u32 %s, %bid_r, %bdim_r, %tid_r;\n",
" setp.ge.u32 %p_oob, %s, %no;\n",
" @%p_oob bra DONE;\n",
"\n",
" setp.ne.u32 %p_lg, %lg, 0; // p_lg = largest\n",
" // slice_off = s * dim (in elements)\n",
" mul.lo.u32 %i, %s, %nd;\n",
" cvt.u64.u32 %slice_off, %i;\n",
" shl.b64 %slice_off, %slice_off, ",
$shift,
";\n",
"\n",
" mov.u32 %j, 0;\n",
"JLOOP:\n",
" setp.ge.u32 %p_jloop, %j, %nk;\n",
" @%p_jloop bra DONE;\n",
"\n",
" setp.eq.u32 %p_first, %j, 0; // j == 0 -> no previous pick\n",
" mov.pred %p_have, 0; // have_best = false\n",
" mov.u32 %i, 0;\n",
"ILOOP:\n",
" setp.ge.u32 %p_iloop, %i, %nd;\n",
" @%p_iloop bra ISTORE;\n",
"\n",
" // cur_val = in[slice + i]\n",
" cvt.u64.u32 %off, %i;\n",
" shl.b64 %off, %off, ",
$shift,
";\n",
" add.u64 %addr, %in_p, %slice_off;\n",
" add.u64 %addr, %addr, %off;\n",
" ld.global.",
$tyld,
" %cur_val, [%addr];\n",
" mov.u32 %cur_idx, %i;\n",
" testp.notanumber.",
$tyld,
" %p_cnan, %cur_val; // p_cnan = isnan(cur)\n",
"\n",
" // eligibility: for j==0 every element is eligible. Otherwise eligible iff\n",
" // `prev` ranks strictly before `cur` in selection order. NaN ordering\n",
" // mirrors torch's GTOp/LTOp comparator with handleNaN=true\n",
" // (aten/src/ATen/native/cuda/SortingCommon.cuh:47-60): NaN compares\n",
" // GREATER than every finite/inf value. So `prev outranks cur`:\n",
" // largest: (isnan(prev) && !isnan(cur)) || (prev > cur)\n",
" // smallest: (isnan(cur) && !isnan(prev)) || (prev < cur)\n",
" // equal-rank (so the ascending-index tie-break applies, incl. NaN==NaN):\n",
" // (isnan(prev) && isnan(cur)) || (prev == cur)\n",
" // `setp.gt/lt/eq` are ORDERED (false if either operand is NaN), so the\n",
" // finite terms need no extra masking; the NaN terms add the ordering.\n",
" testp.notanumber.",
$tyld,
" %p_pnan, %prev_val; // p_pnan = isnan(prev)\n",
" setp.gt.",
$tyld,
" %p_pgt, %prev_val, %cur_val;\n",
" setp.lt.",
$tyld,
" %p_plt, %prev_val, %cur_val;\n",
" setp.eq.",
$tyld,
" %p_peq, %prev_val, %cur_val;\n",
" // NaN-greater terms\n",
" not.pred %p_na, %p_cnan; // !isnan(cur)\n",
" and.pred %p_na, %p_pnan, %p_na; // isnan(prev) && !isnan(cur)\n",
" or.pred %p_pgt, %p_pgt, %p_na; // largest: prev outranks cur\n",
" not.pred %p_nb, %p_pnan; // !isnan(prev)\n",
" and.pred %p_nb, %p_cnan, %p_nb; // isnan(cur) && !isnan(prev)\n",
" or.pred %p_plt, %p_plt, %p_nb; // smallest: prev outranks cur\n",
" and.pred %p_na, %p_pnan, %p_cnan; // isnan(prev) && isnan(cur)\n",
" or.pred %p_peq, %p_peq, %p_na; // equal-rank (incl. NaN==NaN)\n",
" // p_psel = largest ? p_pgt : p_plt\n",
" and.pred %p_psel, %p_lg, %p_pgt;\n",
" not.pred %p_idx, %p_lg;\n",
" and.pred %p_pidx, %p_idx, %p_plt;\n",
" or.pred %p_psel, %p_psel, %p_pidx;\n",
" setp.lt.u32 %p_pidx, %prev_idx, %cur_idx;\n",
" and.pred %p_pidx, %p_peq, %p_pidx; // equal-rank && prev_idx<cur_idx\n",
" or.pred %p_elig, %p_psel, %p_pidx;\n",
" or.pred %p_elig, %p_elig, %p_first; // j==0 -> always eligible\n",
" @!%p_elig bra INEXT;\n",
"\n",
" // candidate beats current best? Same NaN-as-maximum comparator:\n",
" // if !have_best -> yes\n",
" // else largest: (isnan(cur) && !isnan(best)) || (cur > best)\n",
" // || (equal-rank && cur_idx < best_idx)\n",
" // smallest: (isnan(best) && !isnan(cur)) || (cur < best)\n",
" // || (equal-rank && cur_idx < best_idx)\n",
" not.pred %p_upd, %p_have; // !have_best\n",
" testp.notanumber.",
$tyld,
" %p_bnan, %best_val; // p_bnan = isnan(best)\n",
" setp.gt.",
$tyld,
" %p_vgt, %cur_val, %best_val;\n",
" setp.lt.",
$tyld,
" %p_vlt, %cur_val, %best_val;\n",
" setp.eq.",
$tyld,
" %p_veq, %cur_val, %best_val;\n",
" not.pred %p_na, %p_bnan; // !isnan(best)\n",
" and.pred %p_na, %p_cnan, %p_na; // isnan(cur) && !isnan(best)\n",
" or.pred %p_vgt, %p_vgt, %p_na; // largest: cur outranks best\n",
" not.pred %p_nb, %p_cnan; // !isnan(cur)\n",
" and.pred %p_nb, %p_bnan, %p_nb; // isnan(best) && !isnan(cur)\n",
" or.pred %p_vlt, %p_vlt, %p_nb; // smallest: cur outranks best\n",
" and.pred %p_na, %p_cnan, %p_bnan; // isnan(cur) && isnan(best)\n",
" or.pred %p_veq, %p_veq, %p_na; // equal-rank (incl. NaN==NaN)\n",
" and.pred %p_vsel, %p_lg, %p_vgt;\n",
" not.pred %p_idx, %p_lg;\n",
" and.pred %p_idx, %p_idx, %p_vlt;\n",
" or.pred %p_vsel, %p_vsel, %p_idx;\n",
" setp.lt.u32 %p_idx, %cur_idx, %best_idx;\n",
" and.pred %p_idx, %p_veq, %p_idx;\n",
" or.pred %p_beat, %p_vsel, %p_idx;\n",
" and.pred %p_beat, %p_beat, %p_have; // only meaningful when have_best\n",
" or.pred %p_upd, %p_upd, %p_beat;\n",
" @!%p_upd bra INEXT;\n",
"\n",
" mov.",
$tyld,
" %best_val, %cur_val;\n",
" mov.u32 %best_idx, %cur_idx;\n",
" mov.pred %p_have, 1;\n",
"\n",
"INEXT:\n",
" add.u32 %i, %i, 1;\n",
" bra ILOOP;\n",
"\n",
"ISTORE:\n",
" // out position = s * k + j\n",
" mul.lo.u32 %cur_idx, %s, %nk;\n",
" add.u32 %cur_idx, %cur_idx, %j;\n",
" cvt.u64.u32 %off, %cur_idx;\n",
" // store value\n",
" shl.b64 %addr, %off, ",
$shift,
";\n",
" add.u64 %addr, %vp, %addr;\n",
" st.global.",
$tyld,
" [%addr], %best_val;\n",
" // store index (i64)\n",
" shl.b64 %tmp64, %off, 3;\n",
" add.u64 %addr, %ip, %tmp64;\n",
" cvt.s64.u32 %ridx, %best_idx;\n",
" st.global.s64 [%addr], %ridx;\n",
" // prev = best (for next j)\n",
" mov.",
$tyld,
" %prev_val, %best_val;\n",
" mov.u32 %prev_idx, %best_idx;\n",
"\n",
" add.u32 %j, %j, 1;\n",
" bra JLOOP;\n",
"\n",
"DONE:\n",
" ret;\n",
"}\n"
)
};
}
/// Top-k kernel source for `f32`: entry `topk_f32_kernel`, 4-byte elements (shift 2).
const TOPK_F32_PTX: &str = topk_ptx!("topk_f32_kernel", "f32", "2");
/// Top-k kernel source for `f64`: entry `topk_f64_kernel`, 8-byte elements (shift 3).
const TOPK_F64_PTX: &str = topk_ptx!("topk_f64_kernel", "f64", "3");
/// Builds the launch configuration for the top-k kernels: one thread per outer
/// slice, `BLOCK_SIZE` threads per block, at least one block, and no dynamic
/// shared memory.
fn launch_topk_config(outer: usize) -> LaunchConfig {
    // Ceil-divide in `usize` before narrowing: a saturating `+ (BLOCK_SIZE - 1)`
    // in `u32` degrades to a floor when `outer` is within 255 of `u32::MAX`
    // (leaving the last slices unprocessed), and `outer as u32` would truncate
    // anything larger.
    let blocks = outer.div_ceil(BLOCK_SIZE as usize);
    // Clamp instead of wrapping if the block count itself exceeds `u32`.
    let grid = u32::try_from(blocks).unwrap_or(u32::MAX);
    LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// Validates the arguments, then launches the top-k kernel `kernel_name`
/// (compiled from `ptx`) with one thread per outer slice.
///
/// * `input` - device buffer holding at least `outer * dim` elements, laid out
///   as `outer` contiguous slices of `dim` elements.
/// * `k` - number of elements selected per slice; must not exceed `dim`.
/// * `largest` - forwarded to the kernel as `1` (pick largest) or `0`.
///
/// Returns `(values, indices)`, each with `outer * k` elements. Both are empty,
/// and nothing is compiled or launched, when `outer * k == 0`.
///
/// # Errors
///
/// * `GpuError::LengthMismatch` when `k > dim`, when `input` is shorter than
///   `outer * dim`, when `outer`, `dim` or `k` does not fit in `u32`, or when
///   `outer * dim` does not fit in `u32`.
/// * `GpuError::PtxCompileFailed` when the kernel cannot be compiled/loaded.
/// * Any allocation or launch error raised by the driver.
// The parameters mirror the kernel ABI plus the PTX/entry-point pair.
#[allow(clippy::too_many_arguments)]
fn launch_topk<V>(
    input: &CudaSlice<V>,
    outer: usize,
    dim: usize,
    k: usize,
    largest: bool,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<(CudaSlice<V>, CudaSlice<i64>)>
where
    V: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
{
    if k > dim {
        return Err(GpuError::LengthMismatch { a: k, b: dim });
    }
    // Saturating: an overflowing product can never be satisfied by a real
    // buffer, so it falls into the length error below.
    let n_in = outer.saturating_mul(dim);
    if input.len() < n_in {
        return Err(GpuError::LengthMismatch {
            a: input.len(),
            b: n_in,
        });
    }
    if outer > u32::MAX as usize || dim > u32::MAX as usize || k > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: outer.max(dim).max(k),
            b: u32::MAX as usize,
        });
    }
    let stream = device.stream();
    let n_out = outer.saturating_mul(k);
    if n_out == 0 {
        return Ok((stream.alloc_zeros::<V>(0)?, stream.alloc_zeros::<i64>(0)?));
    }
    // The kernel forms the slice offset `s * dim` and the output position
    // `s * k + j` with 32-bit `mul.lo.u32`. Bounding each factor alone is not
    // enough: if `outer * dim` exceeds `u32::MAX` those products wrap and the
    // threads of the later slices read and write the wrong elements. Because
    // `k <= dim`, bounding `outer * dim` bounds `outer * k` as well.
    if n_in > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n_in,
            b: u32::MAX as usize,
        });
    }
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out_vals = stream.alloc_zeros::<V>(n_out)?;
    let mut out_idx = stream.alloc_zeros::<i64>(n_out)?;
    let cfg = launch_topk_config(outer);
    // Lossless casts: all three were range-checked above.
    let outer_u = outer as u32;
    let dim_u = dim as u32;
    let k_u = k as u32;
    let largest_u: u32 = u32::from(largest);
    // SAFETY: the argument order matches the top-k kernels' parameter list
    // (in, vals, idx, outer, dim, k, largest). `input` holds at least
    // `outer * dim` elements and both outputs hold `outer * k`; the kernel
    // guards `s < outer`, `i < dim`, `j < k`, and the checks above keep its
    // 32-bit index arithmetic from wrapping. This assumes `ptx` / `kernel_name`
    // name one of the top-k kernels generated in this file for element type `V`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(&mut out_vals)
            .arg(&mut out_idx)
            .arg(&outer_u)
            .arg(&dim_u)
            .arg(&k_u)
            .arg(&largest_u)
            .launch(cfg)?;
    }
    Ok((out_vals, out_idx))
}
/// Selects `k` elements from each of `outer` contiguous `dim`-element slices of
/// an `f32` device buffer, using the `topk_f32_kernel`.
///
/// `largest` chooses between picking the largest or the smallest values.
/// Returns the selected values together with their `i64` positions inside the
/// slice, `outer * k` entries each.
///
/// # Errors
///
/// Forwards every error produced by `launch_topk`.
pub fn gpu_topk_f32(
    input: &CudaSlice<f32>,
    outer: usize,
    dim: usize,
    k: usize,
    largest: bool,
    device: &GpuDevice,
) -> GpuResult<(CudaSlice<f32>, CudaSlice<i64>)> {
    // Entry point exported by `TOPK_F32_PTX`.
    const ENTRY: &str = "topk_f32_kernel";
    launch_topk(input, outer, dim, k, largest, device, TOPK_F32_PTX, ENTRY)
}
/// Selects `k` elements from each of `outer` contiguous `dim`-element slices of
/// an `f64` device buffer, using the `topk_f64_kernel`.
///
/// `largest` chooses between picking the largest or the smallest values.
/// Returns the selected values together with their `i64` positions inside the
/// slice, `outer * k` entries each.
///
/// # Errors
///
/// Forwards every error produced by `launch_topk`.
pub fn gpu_topk_f64(
    input: &CudaSlice<f64>,
    outer: usize,
    dim: usize,
    k: usize,
    largest: bool,
    device: &GpuDevice,
) -> GpuResult<(CudaSlice<f64>, CudaSlice<i64>)> {
    // Entry point exported by `TOPK_F64_PTX`.
    const ENTRY: &str = "topk_f64_kernel";
    launch_topk(input, outer, dim, k, largest, device, TOPK_F64_PTX, ENTRY)
}
/// PTX source for `histc_f32_kernel`.
///
/// Each thread with index `< n` loads one input element and skips it unless
/// `minv <= v <= maxv` (ordered comparisons, so NaN is skipped). Otherwise it
/// computes `bin = trunc((v - minv) * nbins / (maxv - minv))`, folds
/// `bin == nbins` into the last bin, and atomically adds `1.0` to `out[bin]`.
///
/// NOTE(review): with `nbins == 0` the `bin == nbins` fold subtracts one from
/// zero, so the host side must not launch this kernel with zero bins.
const HISTC_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry histc_f32_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .u32 nbins,
.param .f32 minv,
.param .f32 maxv
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %t, %nn, %nb, %bin, %bin1;
.reg .u64 %in_p, %out_p, %off, %addr;
.reg .f32 %v, %minv, %maxv, %range, %rel, %scaled, %nbf, %binf, %one;
.reg .pred %p_oob, %p_lo, %p_hi, %p_in, %p_last;
ld.param.u64 %in_p, [in_ptr];
ld.param.u64 %out_p, [out_ptr];
ld.param.u32 %nn, [n];
ld.param.u32 %nb, [nbins];
ld.param.f32 %minv, [minv];
ld.param.f32 %maxv, [maxv];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %t, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %t, %nn;
@%p_oob bra DONE;
// v = in[t]
cvt.u64.u32 %off, %t;
shl.b64 %off, %off, 2;
add.u64 %addr, %in_p, %off;
ld.global.f32 %v, [%addr];
// guard (SummaryOps.cu:92): count only when (v >= min && v <= max).
// setp.ge/le are ORDERED (false for NaN) -> NaN skipped, matching torch.
setp.ge.f32 %p_lo, %v, %minv;
setp.le.f32 %p_hi, %v, %maxv;
and.pred %p_in, %p_lo, %p_hi;
@!%p_in bra DONE;
// bin = (int)((v - min) * nbins / (max - min)) (SummaryOps.cu:41)
sub.f32 %rel, %v, %minv;
sub.f32 %range, %maxv, %minv;
cvt.rn.f32.u32 %nbf, %nb;
mul.f32 %scaled, %rel, %nbf;
div.rn.f32 %binf, %scaled, %range;
// truncate toward zero -> u32 bin. rel >= 0 here so trunc == floor.
cvt.rzi.u32.f32 %bin, %binf;
// if (bin == nbins) bin -= 1; (SummaryOps.cu:47-48, last bin [min,max])
setp.eq.u32 %p_last, %bin, %nb;
sub.u32 %bin1, %bin, 1;
@%p_last mov.u32 %bin, %bin1;
// atomicAdd(&out[bin], 1.0f)
cvt.u64.u32 %off, %bin;
shl.b64 %off, %off, 2;
add.u64 %addr, %out_p, %off;
mov.f32 %one, 0f3F800000;
red.global.add.f32 [%addr], %one;
DONE:
ret;
}
";
/// PTX source for `histc_f64_kernel` (declares `.target sm_60`, unlike the
/// `sm_52` kernels in this file, and uses `red.global.add.f64`).
///
/// Each thread with index `< n` loads one input element and skips it unless
/// `minv <= v <= maxv` (ordered comparisons, so NaN is skipped). Otherwise it
/// computes `bin = trunc((v - minv) * nbins / (maxv - minv))`, folds
/// `bin == nbins` into the last bin, and atomically adds `1.0` to `out[bin]`.
///
/// NOTE(review): with `nbins == 0` the `bin == nbins` fold subtracts one from
/// zero, so the host side must not launch this kernel with zero bins.
const HISTC_F64_PTX: &str = "\
.version 7.0
.target sm_60
.address_size 64
.visible .entry histc_f64_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 n,
.param .u32 nbins,
.param .f64 minv,
.param .f64 maxv
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %t, %nn, %nb, %bin, %bin1;
.reg .u64 %in_p, %out_p, %off, %addr;
.reg .f64 %v, %minv, %maxv, %range, %rel, %scaled, %nbf, %binf, %one;
.reg .pred %p_oob, %p_lo, %p_hi, %p_in, %p_last;
ld.param.u64 %in_p, [in_ptr];
ld.param.u64 %out_p, [out_ptr];
ld.param.u32 %nn, [n];
ld.param.u32 %nb, [nbins];
ld.param.f64 %minv, [minv];
ld.param.f64 %maxv, [maxv];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %t, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %t, %nn;
@%p_oob bra DONE;
cvt.u64.u32 %off, %t;
shl.b64 %off, %off, 3;
add.u64 %addr, %in_p, %off;
ld.global.f64 %v, [%addr];
setp.ge.f64 %p_lo, %v, %minv;
setp.le.f64 %p_hi, %v, %maxv;
and.pred %p_in, %p_lo, %p_hi;
@!%p_in bra DONE;
sub.f64 %rel, %v, %minv;
sub.f64 %range, %maxv, %minv;
cvt.rn.f64.u32 %nbf, %nb;
mul.f64 %scaled, %rel, %nbf;
div.rn.f64 %binf, %scaled, %range;
cvt.rzi.u32.f64 %bin, %binf;
setp.eq.u32 %p_last, %bin, %nb;
sub.u32 %bin1, %bin, 1;
@%p_last mov.u32 %bin, %bin1;
cvt.u64.u32 %off, %bin;
shl.b64 %off, %off, 3;
add.u64 %addr, %out_p, %off;
mov.f64 %one, 0d3FF0000000000000;
red.global.add.f64 [%addr], %one;
DONE:
ret;
}
";
/// Launch configuration for the histogram kernels: simply the `launch_1d`
/// configuration for `n` input elements (one thread per element).
fn launch_histc_config(n: usize) -> LaunchConfig {
launch_1d(n)
}
/// Validates the arguments, then launches the histogram kernel `kernel_name`
/// (compiled from `ptx`) with one thread per input element.
///
/// * `input` - device buffer holding at least `n` elements.
/// * `bins` - number of histogram bins, i.e. the length of the result.
/// * `min_val` / `max_val` - inclusive value range; the kernel skips elements
///   outside it (and NaN).
///
/// Returns a zero-initialised device buffer of `bins` counters with the counts
/// accumulated into it. When `n == 0` or `bins == 0` the zeroed (possibly
/// empty) buffer is returned without compiling or launching anything.
///
/// # Errors
///
/// * `GpuError::LengthMismatch` when `input` is shorter than `n`, or when `n`
///   or `bins` does not fit in `u32`.
/// * `GpuError::PtxCompileFailed` when the kernel cannot be compiled/loaded.
/// * Any allocation or launch error raised by the driver.
// The parameters mirror the kernel ABI plus the PTX/entry-point pair.
#[allow(clippy::too_many_arguments)]
fn launch_histc<V>(
    input: &CudaSlice<V>,
    n: usize,
    bins: usize,
    min_val: V,
    max_val: V,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<CudaSlice<V>>
where
    V: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits + Copy,
{
    if input.len() < n {
        return Err(GpuError::LengthMismatch {
            a: input.len(),
            b: n,
        });
    }
    if n > u32::MAX as usize || bins > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n.max(bins),
            b: u32::MAX as usize,
        });
    }
    let stream = device.stream();
    let mut out = stream.alloc_zeros::<V>(bins)?;
    // Nothing to count, or nowhere to count into. The `bins == 0` case must
    // not reach the kernel: an in-range value yields bin 0, which equals
    // `nbins`, so the kernel's "last bin" fold computes `0 - 1` and the atomic
    // add lands at index `u32::MAX` of an empty buffer.
    if n == 0 || bins == 0 {
        return Ok(out);
    }
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let cfg = launch_histc_config(n);
    // Lossless casts: both were range-checked above.
    let n_u = n as u32;
    let bins_u = bins as u32;
    // SAFETY: the argument order matches the histogram kernels' parameter list
    // (in, out, n, nbins, minv, maxv). `input` holds at least `n` elements and
    // the kernel only reads `in[t]` for `t < n`; `out` holds `bins >= 1`
    // counters and the kernel folds `bin == nbins` back into the last bin. This
    // assumes `ptx` / `kernel_name` name one of the histogram kernels defined
    // in this file for element type `V`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(&mut out)
            .arg(&n_u)
            .arg(&bins_u)
            .arg(&min_val)
            .arg(&max_val)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Computes a `bins`-bucket histogram of the first `n` elements of an `f32`
/// device buffer over the inclusive range `[min_val, max_val]`, using the
/// `histc_f32_kernel`.
///
/// The result holds one `f32` counter per bin.
///
/// # Errors
///
/// Forwards every error produced by `launch_histc`.
pub fn gpu_histc_f32(
    input: &CudaSlice<f32>,
    n: usize,
    bins: usize,
    min_val: f32,
    max_val: f32,
    device: &GpuDevice,
) -> GpuResult<CudaSlice<f32>> {
    // Entry point exported by `HISTC_F32_PTX`.
    const ENTRY: &str = "histc_f32_kernel";
    launch_histc(input, n, bins, min_val, max_val, device, HISTC_F32_PTX, ENTRY)
}
/// Computes a `bins`-bucket histogram of the first `n` elements of an `f64`
/// device buffer over the inclusive range `[min_val, max_val]`, using the
/// `histc_f64_kernel`.
///
/// The result holds one `f64` counter per bin.
///
/// # Errors
///
/// Forwards every error produced by `launch_histc`.
pub fn gpu_histc_f64(
    input: &CudaSlice<f64>,
    n: usize,
    bins: usize,
    min_val: f64,
    max_val: f64,
    device: &GpuDevice,
) -> GpuResult<CudaSlice<f64>> {
    // Entry point exported by `HISTC_F64_PTX`.
    const ENTRY: &str = "histc_f64_kernel";
    launch_histc(input, n, bins, min_val, max_val, device, HISTC_F64_PTX, ENTRY)
}
/// PTX source for `meshgrid_f32_kernel`.
///
/// Each thread with flat index `t < total` computes
/// `coord = (t / inner) % axis_len` and copies `in[coord]` to `out[t]`.
/// Both `inner` and `axis_len` are used as divisors, so neither may be zero
/// when the kernel is launched.
const MESHGRID_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry meshgrid_f32_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 total,
.param .u32 inner,
.param .u32 axis_len
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %t, %tot, %inr, %al, %q, %coord;
.reg .u64 %in_p, %out_p, %off, %addr;
.reg .f32 %v;
.reg .pred %p_oob;
ld.param.u64 %in_p, [in_ptr];
ld.param.u64 %out_p, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %inr, [inner];
ld.param.u32 %al, [axis_len];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %t, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %t, %tot;
@%p_oob bra DONE;
// coord = (flat / inner) % axis_len
div.u32 %q, %t, %inr;
rem.u32 %coord, %q, %al;
// v = in[coord]; out[flat] = v
cvt.u64.u32 %off, %coord;
shl.b64 %off, %off, 2;
add.u64 %addr, %in_p, %off;
ld.global.f32 %v, [%addr];
cvt.u64.u32 %off, %t;
shl.b64 %off, %off, 2;
add.u64 %addr, %out_p, %off;
st.global.f32 [%addr], %v;
DONE:
ret;
}
";
/// PTX source for `meshgrid_f64_kernel`.
///
/// Each thread with flat index `t < total` computes
/// `coord = (t / inner) % axis_len` and copies the 8-byte element `in[coord]`
/// to `out[t]`. Both `inner` and `axis_len` are used as divisors, so neither
/// may be zero when the kernel is launched.
const MESHGRID_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry meshgrid_f64_kernel(
.param .u64 in_ptr,
.param .u64 out_ptr,
.param .u32 total,
.param .u32 inner,
.param .u32 axis_len
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %t, %tot, %inr, %al, %q, %coord;
.reg .u64 %in_p, %out_p, %off, %addr;
.reg .f64 %v;
.reg .pred %p_oob;
ld.param.u64 %in_p, [in_ptr];
ld.param.u64 %out_p, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %inr, [inner];
ld.param.u32 %al, [axis_len];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %t, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %t, %tot;
@%p_oob bra DONE;
div.u32 %q, %t, %inr;
rem.u32 %coord, %q, %al;
cvt.u64.u32 %off, %coord;
shl.b64 %off, %off, 3;
add.u64 %addr, %in_p, %off;
ld.global.f64 %v, [%addr];
cvt.u64.u32 %off, %t;
shl.b64 %off, %off, 3;
add.u64 %addr, %out_p, %off;
st.global.f64 [%addr], %v;
DONE:
ret;
}
";
/// Validates the arguments, then launches the meshgrid kernel `kernel_name`
/// (compiled from `ptx`) with one thread per output element.
///
/// * `input` - device buffer holding at least `axis_len` coordinate values.
/// * `total` - number of output elements.
/// * `inner` - divisor applied to the flat output index; a value of `0` is
///   treated as `1`.
/// * `axis_len` - length of the broadcast axis; output element `t` receives
///   `input[(t / inner) % axis_len]`.
///
/// Returns a device buffer of `total` elements. When `total == 0` the empty
/// buffer is returned without compiling or launching anything.
///
/// # Errors
///
/// * `GpuError::LengthMismatch` when `input` is shorter than `axis_len`, when
///   `total`, `inner` or `axis_len` does not fit in `u32`, or when
///   `axis_len == 0` while `total > 0`.
/// * `GpuError::PtxCompileFailed` when the kernel cannot be compiled/loaded.
/// * Any allocation or launch error raised by the driver.
fn launch_meshgrid<V>(
    input: &CudaSlice<V>,
    total: usize,
    inner: usize,
    axis_len: usize,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<CudaSlice<V>>
where
    V: cudarc::driver::DeviceRepr + cudarc::driver::ValidAsZeroBits,
{
    if input.len() < axis_len {
        return Err(GpuError::LengthMismatch {
            a: input.len(),
            b: axis_len,
        });
    }
    if total > u32::MAX as usize || inner > u32::MAX as usize || axis_len > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: total.max(inner).max(axis_len),
            b: u32::MAX as usize,
        });
    }
    let stream = device.stream();
    let mut out = stream.alloc_zeros::<V>(total)?;
    if total == 0 {
        return Ok(out);
    }
    // A non-empty output needs a non-empty axis: the kernel evaluates
    // `% axis_len` (a remainder by zero) and then reads `input[coord]`, which
    // for an empty axis has no valid element to read.
    if axis_len == 0 {
        return Err(GpuError::LengthMismatch {
            a: axis_len,
            b: total,
        });
    }
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let cfg = launch_1d(total);
    // Lossless casts: all three were range-checked above.
    let total_u = total as u32;
    // The kernel divides by `inner`; clamp to 1 so a zero never reaches it.
    let inner_u = inner.max(1) as u32;
    let axis_u = axis_len as u32;
    // SAFETY: the argument order matches the meshgrid kernels' parameter list
    // (in, out, total, inner, axis_len). `out` holds `total` elements and the
    // kernel only writes `out[t]` for `t < total`; it reads `in[coord]` with
    // `coord < axis_len <= input.len()`, and both divisors are non-zero. This
    // assumes `ptx` / `kernel_name` name one of the meshgrid kernels defined in
    // this file for element type `V`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(&mut out)
            .arg(&total_u)
            .arg(&inner_u)
            .arg(&axis_u)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Expands an `f32` coordinate vector of length `axis_len` into a flat grid of
/// `total` elements on the device, using the `meshgrid_f32_kernel`: output
/// element `t` receives `input[(t / inner) % axis_len]`.
///
/// # Errors
///
/// Forwards every error produced by `launch_meshgrid`.
pub fn gpu_meshgrid_f32(
    input: &CudaSlice<f32>,
    total: usize,
    inner: usize,
    axis_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaSlice<f32>> {
    // Entry point exported by `MESHGRID_F32_PTX`.
    const ENTRY: &str = "meshgrid_f32_kernel";
    launch_meshgrid(input, total, inner, axis_len, device, MESHGRID_F32_PTX, ENTRY)
}
/// Expands an `f64` coordinate vector of length `axis_len` into a flat grid of
/// `total` elements on the device, using the `meshgrid_f64_kernel`: output
/// element `t` receives `input[(t / inner) % axis_len]`.
///
/// # Errors
///
/// Forwards every error produced by `launch_meshgrid`.
pub fn gpu_meshgrid_f64(
    input: &CudaSlice<f64>,
    total: usize,
    inner: usize,
    axis_len: usize,
    device: &GpuDevice,
) -> GpuResult<CudaSlice<f64>> {
    // Entry point exported by `MESHGRID_F64_PTX`.
    const ENTRY: &str = "meshgrid_f64_kernel";
    launch_meshgrid(input, total, inner, axis_len, device, MESHGRID_F64_PTX, ENTRY)
}
/// PTX source for `run_flag_f32_kernel`.
///
/// Each thread with index `idx < n` writes an `f32` flag to `flag[idx]`: `1.0`
/// when the element starts a run (`idx == 0`, or `in[idx]` compares
/// unordered-not-equal to `in[idx - 1]`), otherwise `0.0`. Because the
/// comparison is `setp.neu`, every NaN starts its own run.
const RUN_FLAG_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry run_flag_f32_kernel(
.param .u64 in_ptr,
.param .u64 flag_ptr,
.param .u32 n
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %idx, %n_r;
.reg .u64 %in_p, %flag_p, %off, %addr, %prev_addr;
.reg .f32 %cur, %prev, %one, %zero;
.reg .pred %p_oob, %p_first, %p_ne;
ld.param.u64 %in_p, [in_ptr];
ld.param.u64 %flag_p, [flag_ptr];
ld.param.u32 %n_r, [n];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %idx, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %idx, %n_r;
@%p_oob bra DONE;
mov.f32 %one, 0f3F800000;
mov.f32 %zero, 0f00000000;
// off = idx * 4 (f32 element stride)
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %addr, %in_p, %off;
ld.global.f32 %cur, [%addr];
// idx == 0 is always a run-start.
setp.eq.u32 %p_first, %idx, 0;
@%p_first bra WRITE_ONE;
// prev = in[idx-1]
sub.u64 %prev_addr, %addr, 4;
ld.global.f32 %prev, [%prev_addr];
// run-start iff cur != prev. setp.neu is the UNORDERED not-equal: NaN vs
// NaN -> true and NaN vs finite -> true, so every NaN starts its own run
// (matching the CPU `data[i] == data[i-1]` negation and torch). The ordered
// setp.ne returns FALSE for NaN operands and would collapse consecutive NaNs.
setp.neu.f32 %p_ne, %cur, %prev;
@%p_ne bra WRITE_ONE;
// not a run-start
add.u64 %addr, %flag_p, %off;
st.global.f32 [%addr], %zero;
bra DONE;
WRITE_ONE:
add.u64 %addr, %flag_p, %off;
st.global.f32 [%addr], %one;
DONE:
ret;
}
";
/// PTX source for `run_flag_f64_kernel`.
///
/// Reads `f64` input (8-byte stride) but writes `f32` flags (4-byte stride):
/// each thread with index `idx < n` stores `1.0` to `flag[idx]` when the
/// element starts a run (`idx == 0`, or `in[idx]` compares
/// unordered-not-equal to `in[idx - 1]`), otherwise `0.0`. Because the
/// comparison is `setp.neu`, every NaN starts its own run.
const RUN_FLAG_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry run_flag_f64_kernel(
.param .u64 in_ptr,
.param .u64 flag_ptr,
.param .u32 n
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %idx, %n_r;
.reg .u64 %in_p, %flag_p, %ioff, %foff, %addr, %prev_addr;
.reg .f64 %cur, %prev;
.reg .f32 %one, %zero;
.reg .pred %p_oob, %p_first, %p_ne;
ld.param.u64 %in_p, [in_ptr];
ld.param.u64 %flag_p, [flag_ptr];
ld.param.u32 %n_r, [n];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %idx, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %idx, %n_r;
@%p_oob bra DONE;
mov.f32 %one, 0f3F800000;
mov.f32 %zero, 0f00000000;
// ioff = idx * 8 (f64 input stride); foff = idx * 4 (f32 flag stride)
cvt.u64.u32 %ioff, %idx;
shl.b64 %ioff, %ioff, 3;
cvt.u64.u32 %foff, %idx;
shl.b64 %foff, %foff, 2;
add.u64 %addr, %in_p, %ioff;
ld.global.f64 %cur, [%addr];
setp.eq.u32 %p_first, %idx, 0;
@%p_first bra WRITE_ONE;
sub.u64 %prev_addr, %addr, 8;
ld.global.f64 %prev, [%prev_addr];
// run-start iff cur != prev. setp.neu (unordered not-equal) makes every NaN
// its own run (NaN vs NaN -> true), matching the CPU path and torch; the
// ordered setp.ne returns FALSE for NaN and would collapse consecutive NaNs.
setp.neu.f64 %p_ne, %cur, %prev;
@%p_ne bra WRITE_ONE;
add.u64 %addr, %flag_p, %foff;
st.global.f32 [%addr], %zero;
bra DONE;
WRITE_ONE:
add.u64 %addr, %flag_p, %foff;
st.global.f32 [%addr], %one;
DONE:
ret;
}
";
/// PTX source for `compact_f32_kernel`.
///
/// Each thread with index `idx < n` re-derives the run-start test (`idx == 0`,
/// or `in[idx]` unordered-not-equal to `in[idx - 1]`). Threads that are not at
/// a run start exit; the others read the `f32` value `incl[idx]`, truncate it
/// to `u32`, subtract one, and store `in[idx]` at that position in `out`.
const COMPACT_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry compact_f32_kernel(
.param .u64 in_ptr,
.param .u64 incl_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %idx, %n_r, %pos;
.reg .u64 %in_p, %incl_p, %out_p, %off, %addr, %prev_addr, %ooff;
.reg .f32 %cur, %prev, %incl;
.reg .pred %p_oob, %p_first, %p_ne;
ld.param.u64 %in_p, [in_ptr];
ld.param.u64 %incl_p, [incl_ptr];
ld.param.u64 %out_p, [out_ptr];
ld.param.u32 %n_r, [n];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %idx, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %idx, %n_r;
@%p_oob bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %addr, %in_p, %off;
ld.global.f32 %cur, [%addr];
// Re-derive run-start from the data (same predicate as RUN_FLAG).
setp.eq.u32 %p_first, %idx, 0;
@%p_first bra DO_STORE;
sub.u64 %prev_addr, %addr, 4;
ld.global.f32 %prev, [%prev_addr];
// setp.neu (unordered) must match RUN_FLAG exactly: NaN -> own run, else
// out_len and the scatter positions disagree.
setp.neu.f32 %p_ne, %cur, %prev;
@%p_ne bra DO_STORE;
bra DONE;
DO_STORE:
// pos = (u32)incl[idx] - 1
add.u64 %addr, %incl_p, %off;
ld.global.f32 %incl, [%addr];
cvt.rzi.u32.f32 %pos, %incl;
sub.u32 %pos, %pos, 1;
// out[pos] = cur
cvt.u64.u32 %ooff, %pos;
shl.b64 %ooff, %ooff, 2;
add.u64 %addr, %out_p, %ooff;
st.global.f32 [%addr], %cur;
DONE:
ret;
}
";
/// PTX source for `compact_f64_kernel`.
///
/// Reads `f64` input (8-byte stride) and `f32` ranks from `incl` (4-byte
/// stride). Each thread with index `idx < n` re-derives the run-start test
/// (`idx == 0`, or `in[idx]` unordered-not-equal to `in[idx - 1]`). Threads
/// that are not at a run start exit; the others truncate `incl[idx]` to `u32`,
/// subtract one, and store `in[idx]` at that position in `out`.
const COMPACT_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry compact_f64_kernel(
.param .u64 in_ptr,
.param .u64 incl_ptr,
.param .u64 out_ptr,
.param .u32 n
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %idx, %n_r, %pos;
.reg .u64 %in_p, %incl_p, %out_p, %ioff, %foff, %addr, %prev_addr, %ooff;
.reg .f64 %cur, %prev;
.reg .f32 %incl;
.reg .pred %p_oob, %p_first, %p_ne;
ld.param.u64 %in_p, [in_ptr];
ld.param.u64 %incl_p, [incl_ptr];
ld.param.u64 %out_p, [out_ptr];
ld.param.u32 %n_r, [n];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %idx, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %idx, %n_r;
@%p_oob bra DONE;
cvt.u64.u32 %ioff, %idx;
shl.b64 %ioff, %ioff, 3;
add.u64 %addr, %in_p, %ioff;
ld.global.f64 %cur, [%addr];
setp.eq.u32 %p_first, %idx, 0;
@%p_first bra DO_STORE;
sub.u64 %prev_addr, %addr, 8;
ld.global.f64 %prev, [%prev_addr];
// setp.neu (unordered) must match RUN_FLAG_F64 exactly: NaN -> own run, else
// out_len and the scatter positions disagree.
setp.neu.f64 %p_ne, %cur, %prev;
@%p_ne bra DO_STORE;
bra DONE;
DO_STORE:
cvt.u64.u32 %foff, %idx;
shl.b64 %foff, %foff, 2;
add.u64 %addr, %incl_p, %foff;
ld.global.f32 %incl, [%addr];
cvt.rzi.u32.f32 %pos, %incl;
sub.u32 %pos, %pos, 1;
cvt.u64.u32 %ooff, %pos;
shl.b64 %ooff, %ooff, 3;
add.u64 %addr, %out_p, %ooff;
st.global.f64 [%addr], %cur;
DONE:
ret;
}
";
/// Flags every run start in `in_slice` and returns the inclusive prefix sum of
/// those flags.
///
/// The kernel `flag_kernel` (compiled from `flag_ptx`) is launched with one
/// thread per element and fills an `f32` buffer of length `n`; that buffer is
/// then passed to `gpu_cumsum`, whose result is returned.
///
/// `n` is narrowed to `u32` without a check here, so callers are expected to
/// have validated it.
///
/// # Errors
///
/// Returns `GpuError::PtxCompileFailed` when the kernel cannot be
/// compiled/loaded, and forwards allocation, launch and `gpu_cumsum` errors.
#[cfg(feature = "cuda")]
fn run_flags_and_scan(
    in_slice: &CudaSlice<impl cudarc::driver::DeviceRepr>,
    n: usize,
    device: &GpuDevice,
    flag_ptx: &'static str,
    flag_kernel: &'static str,
) -> GpuResult<CudaBuffer<f32>> {
    const THREADS_PER_BLOCK: u32 = 256;

    let stream = device.stream();
    let ctx = device.context();

    // Allocate first, compile second - same order of fallible steps as before.
    let mut run_starts = alloc_zeros_f32(n, device)?;
    let kernel = get_or_compile(ctx, flag_ptx, flag_kernel, device.ordinal() as u32).map_err(
        |source| GpuError::PtxCompileFailed {
            kernel: flag_kernel,
            source,
        },
    )?;

    let len = n as u32;
    let blocks = len.div_ceil(THREADS_PER_BLOCK).max(1);
    let launch_cfg = LaunchConfig {
        grid_dim: (blocks, 1, 1),
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };
    unsafe {
        stream
            .launch_builder(&kernel)
            .arg(in_slice)
            .arg(run_starts.inner_mut())
            .arg(&len)
            .launch(launch_cfg)?;
    }

    let scanned = gpu_cumsum(&run_starts, 1, n, 1, device)?;
    Ok(scanned)
}
/// Turns an inclusive scan of run-start flags into per-element run indices and
/// per-run lengths.
///
/// `incl_host[i]` is the 1-based rank of the run containing element `i`.
/// Returns `(inverse, counts, out_len)` where `inverse[i]` is the 0-based run
/// index of element `i`, `counts[r]` is the number of elements in run `r`, and
/// `out_len` is the last scan value truncated to `usize` (the run count).
/// An empty input yields `(vec![], vec![], 0)`.
///
/// # Panics
///
/// Panics if a scan value maps to a run index outside `0..out_len`.
fn decode_runs(incl_host: &[f32]) -> (Vec<usize>, Vec<usize>, usize) {
    let Some(&last_rank) = incl_host.last() else {
        return (Vec::new(), Vec::new(), 0);
    };
    let out_len = last_rank as usize;
    let mut counts = vec![0usize; out_len];
    let inverse: Vec<usize> = incl_host
        .iter()
        .map(|&rank| {
            // Ranks are 1-based; run indices are 0-based.
            let run = (rank as usize) - 1;
            counts[run] += 1;
            run
        })
        .collect();
    (inverse, counts, out_len)
}
/// Collapses consecutive equal values in the first `n` elements of an `f32`
/// device buffer.
///
/// Returns `(values, inverse, counts)`:
/// * `values` - device buffer with the first element of every run, in order;
/// * `inverse` - for each input element, the index of its run in `values`;
/// * `counts` - the length of each run.
///
/// Runs are delimited with an unordered not-equal comparison on the device, so
/// every NaN forms its own run. `n == 0` yields three empty results.
///
/// # Errors
///
/// * `GpuError::LengthMismatch` when `n` does not fit in `u32`, or when the
///   input contains 2^24 or more runs (the run ranks are accumulated in `f32`,
///   which cannot count exactly past that point).
/// * Any error forwarded from kernel compilation, launch, the scan, the
///   device-to-host copy or allocation.
#[cfg(feature = "cuda")]
pub fn gpu_unique_consecutive_f32(
    input: &CudaBuffer<f32>,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, Vec<usize>, Vec<usize>)> {
    // 2^24: the largest count up to which every integer is exact in `f32`.
    const F32_EXACT_INT_LIMIT: f32 = 16_777_216.0;

    if n == 0 {
        return Ok((alloc_zeros_f32(0, device)?, vec![], vec![]));
    }
    if n > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: u32::MAX as usize,
        });
    }
    let incl = run_flags_and_scan(
        input.inner(),
        n,
        device,
        RUN_FLAG_F32_PTX,
        "run_flag_f32_kernel",
    )?;
    let incl_host = gpu_to_cpu(&incl, device)?;
    let ranks: &[f32] = &incl_host;
    // The scan adds 0/1 flags in `f32`. Once the running total reaches 2^24,
    // `x + 1.0` can round back to `x`, so ranks stop being exact and both the
    // host decode and the device scatter would use wrong positions. A rounded
    // sum of non-negative terms never drops below 2^24 once the true total
    // passes it, so checking the final value catches every inexact scan.
    if let Some(&total_runs) = ranks.last() {
        if total_runs >= F32_EXACT_INT_LIMIT {
            return Err(GpuError::LengthMismatch {
                a: total_runs as usize,
                b: F32_EXACT_INT_LIMIT as usize,
            });
        }
    }
    let (inverse, counts, out_len) = decode_runs(ranks);
    let mut out = alloc_zeros_f32(out_len, device)?;
    launch_compact_f32(input, &incl, &mut out, n, device)?;
    Ok((out, inverse, counts))
}
/// Collapses runs of consecutive equal values in the first `n` elements of `input`.
///
/// Returns `(out, inverse, counts)`:
/// * `out` – device buffer with one slot per run (filled by `compact_f64_kernel`),
/// * `inverse` – for every input position, the 0-based index of its run,
/// * `counts` – number of input elements in each run.
///
/// The per-element run numbers are produced on the device as `f32` (even for
/// `f64` data) and read back to the host for decoding.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] when `n > u32::MAX`, or when `n > 2^24` and
///   the run count reaches 2^24: beyond that an `f32` can no longer hold the
///   count exactly, so the result would be silently wrong.
/// * Any allocation, PTX compilation, launch, scan or readback failure.
#[cfg(feature = "cuda")]
pub fn gpu_unique_consecutive_f64(
    input: &CudaBuffer<f64>,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, Vec<usize>, Vec<usize>)> {
    // `f32` has a 24-bit significand: every integer up to 2^24 is exact, odd
    // integers above it are not representable.
    const F32_EXACT_INT_MAX: usize = 1 << f32::MANTISSA_DIGITS;

    if n == 0 {
        return Ok((alloc_zeros_f64(0, device)?, vec![], vec![]));
    }
    // The kernels index with u32.
    if n > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: u32::MAX as usize,
        });
    }
    let incl = run_flags_and_scan(
        input.inner(),
        n,
        device,
        RUN_FLAG_F64_PTX,
        "run_flag_f64_kernel",
    )?;
    let incl_host = gpu_to_cpu(&incl, device)?;

    // With n <= 2^24 the scan cannot exceed 2^24 and is exact. For larger n a
    // total that has reached 2^24 may already have lost increments, so refuse
    // instead of returning corrupted indices.
    let run_total = incl_host.last().copied().unwrap_or(0.0);
    if n > F32_EXACT_INT_MAX && run_total >= F32_EXACT_INT_MAX as f32 {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: F32_EXACT_INT_MAX,
        });
    }

    let (inverse, counts, out_len) = decode_runs(&incl_host);
    let mut out = alloc_zeros_f64(out_len, device)?;
    launch_compact_f64(input, &incl, &mut out, n, device)?;
    Ok((out, inverse, counts))
}
/// Launches `compact_f32_kernel` with the arguments `(input, incl, out, n as u32)`.
///
/// One thread is started per input element (blocks of 256, at least one block).
/// The kernel source lives in `COMPACT_F32_PTX`; presumably it stores the first
/// element of every run at `out[incl - 1]` — verify against the PTX.
///
/// # Errors
/// Returns [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained and
/// propagates a failed launch.
#[cfg(feature = "cuda")]
fn launch_compact_f32(
    input: &CudaBuffer<f32>,
    incl: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f32>,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    const KERNEL_NAME: &str = "compact_f32_kernel";
    const THREADS_PER_BLOCK: u32 = 256;

    let launch_stream = device.stream();
    let context = device.context();
    let kernel = get_or_compile(
        context,
        COMPACT_F32_PTX,
        KERNEL_NAME,
        device.ordinal() as u32,
    )
    .map_err(|source| GpuError::PtxCompileFailed {
        kernel: KERNEL_NAME,
        source,
    })?;

    let elem_count = n as u32;
    let launch_cfg = LaunchConfig {
        grid_dim: (elem_count.div_ceil(THREADS_PER_BLOCK).max(1), 1, 1),
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };

    // SAFETY: assumes `input` and `incl` hold at least `n` elements and that
    // `out` has one slot per run — TODO confirm at the call sites.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(incl.inner())
            .arg(out.inner_mut())
            .arg(&elem_count)
            .launch(launch_cfg)?;
    }
    Ok(())
}
/// Launches `compact_f64_kernel` with the arguments `(input, incl, out, n as u32)`.
///
/// One thread is started per input element (blocks of 256, at least one block).
/// The scan buffer `incl` is `f32` even though the data is `f64`.
///
/// # Errors
/// Returns [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained and
/// propagates a failed launch.
#[cfg(feature = "cuda")]
fn launch_compact_f64(
    input: &CudaBuffer<f64>,
    incl: &CudaBuffer<f32>,
    out: &mut CudaBuffer<f64>,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    const KERNEL_NAME: &str = "compact_f64_kernel";
    const THREADS_PER_BLOCK: u32 = 256;

    let launch_stream = device.stream();
    let context = device.context();
    let kernel = get_or_compile(
        context,
        COMPACT_F64_PTX,
        KERNEL_NAME,
        device.ordinal() as u32,
    )
    .map_err(|source| GpuError::PtxCompileFailed {
        kernel: KERNEL_NAME,
        source,
    })?;

    let elem_count = n as u32;
    let launch_cfg = LaunchConfig {
        grid_dim: (elem_count.div_ceil(THREADS_PER_BLOCK).max(1), 1, 1),
        block_dim: (THREADS_PER_BLOCK, 1, 1),
        shared_mem_bytes: 0,
    };

    // SAFETY: assumes `input` and `incl` hold at least `n` elements and that
    // `out` has one slot per run — TODO confirm at the call sites.
    unsafe {
        launch_stream
            .launch_builder(&kernel)
            .arg(input.inner())
            .arg(incl.inner())
            .arg(out.inner_mut())
            .arg(&elem_count)
            .launch(launch_cfg)?;
    }
    Ok(())
}
/// PTX source of `unique_init_f32_kernel(in_ptr, key_ptr, idx_ptr, n, npad)`.
///
/// Each thread handles one slot `i` (threads with `i >= npad` exit at once):
/// * `i < n`  – copies `in[i]` to `key[i]` and stores `i` in `idx[i]`;
/// * otherwise – stores `+INF` in `key[i]` and `2147483647` (`i32::MAX`) in
///   `idx[i]`, marking the slot as padding.
///
/// Keys and indices are both addressed with a 4-byte stride.
const UNIQUE_INIT_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry unique_init_f32_kernel(
.param .u64 in_ptr,
.param .u64 key_ptr,
.param .u64 idx_ptr,
.param .u32 n,
.param .u32 npad
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %i, %n_r, %npad_r, %pad_idx;
.reg .u64 %in_p, %key_p, %idx_p, %koff, %ioff, %addr;
.reg .f32 %v, %inf;
.reg .pred %p_oob, %p_real;
ld.param.u64 %in_p, [in_ptr];
ld.param.u64 %key_p, [key_ptr];
ld.param.u64 %idx_p, [idx_ptr];
ld.param.u32 %n_r, [n];
ld.param.u32 %npad_r, [npad];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %i, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %i, %npad_r;
@%p_oob bra DONE;
// koff = i * 4 (f32 key stride); ioff = i * 4 (i32 idx stride)
cvt.u64.u32 %koff, %i;
shl.b64 %koff, %koff, 2;
cvt.u64.u32 %ioff, %i;
shl.b64 %ioff, %ioff, 2;
mov.u32 %pad_idx, 2147483647; // i32::MAX
mov.f32 %inf, 0f7F800000; // +INF
setp.lt.u32 %p_real, %i, %n_r;
@!%p_real bra WRITE_PAD;
// real: key[i] = in[i], idx[i] = i
add.u64 %addr, %in_p, %koff;
ld.global.f32 %v, [%addr];
add.u64 %addr, %key_p, %koff;
st.global.f32 [%addr], %v;
add.u64 %addr, %idx_p, %ioff;
st.global.u32 [%addr], %i;
bra DONE;
WRITE_PAD:
add.u64 %addr, %key_p, %koff;
st.global.f32 [%addr], %inf;
add.u64 %addr, %idx_p, %ioff;
st.global.u32 [%addr], %pad_idx;
DONE:
ret;
}
";
/// PTX source of `unique_init_f64_kernel(in_ptr, key_ptr, idx_ptr, n, npad)`.
///
/// Each thread handles one slot `i` (threads with `i >= npad` exit at once):
/// * `i < n`  – copies `in[i]` to `key[i]` and stores `i` in `idx[i]`;
/// * otherwise – stores `+INF` in `key[i]` and `2147483647` (`i32::MAX`) in
///   `idx[i]`, marking the slot as padding.
///
/// Keys use an 8-byte stride, indices a 4-byte stride.
const UNIQUE_INIT_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry unique_init_f64_kernel(
.param .u64 in_ptr,
.param .u64 key_ptr,
.param .u64 idx_ptr,
.param .u32 n,
.param .u32 npad
) {
.reg .u32 %tid_r, %bid_r, %bdim_r, %i, %n_r, %npad_r, %pad_idx;
.reg .u64 %in_p, %key_p, %idx_p, %koff, %ioff, %addr;
.reg .f64 %v, %inf;
.reg .pred %p_oob, %p_real;
ld.param.u64 %in_p, [in_ptr];
ld.param.u64 %key_p, [key_ptr];
ld.param.u64 %idx_p, [idx_ptr];
ld.param.u32 %n_r, [n];
ld.param.u32 %npad_r, [npad];
mov.u32 %tid_r, %tid.x;
mov.u32 %bid_r, %ctaid.x;
mov.u32 %bdim_r, %ntid.x;
mad.lo.u32 %i, %bid_r, %bdim_r, %tid_r;
setp.ge.u32 %p_oob, %i, %npad_r;
@%p_oob bra DONE;
// koff = i * 8 (f64 key stride); ioff = i * 4 (i32 idx stride)
cvt.u64.u32 %koff, %i;
shl.b64 %koff, %koff, 3;
cvt.u64.u32 %ioff, %i;
shl.b64 %ioff, %ioff, 2;
mov.u32 %pad_idx, 2147483647;
mov.f64 %inf, 0d7FF0000000000000; // +INF (f64)
setp.lt.u32 %p_real, %i, %n_r;
@!%p_real bra WRITE_PAD;
add.u64 %addr, %in_p, %koff;
ld.global.f64 %v, [%addr];
add.u64 %addr, %key_p, %koff;
st.global.f64 [%addr], %v;
add.u64 %addr, %idx_p, %ioff;
st.global.u32 [%addr], %i;
bra DONE;
WRITE_PAD:
add.u64 %addr, %key_p, %koff;
st.global.f64 [%addr], %inf;
add.u64 %addr, %idx_p, %ioff;
st.global.u32 [%addr], %pad_idx;
DONE:
ret;
}
";
// Builds, at compile time, the PTX text of one bitonic compare-exchange kernel.
//
// Arguments:
//   $name   - kernel entry name,
//   $tyld   - PTX float type of the keys ("f32" or "f64"),
//   $kbytes - log2 of the key size in bytes, used as the shift for key offsets
//             ("2" for f32, "3" for f64).
//
// The generated kernel takes (key_ptr, idx_ptr, npad, j, k). Thread `i` pairs
// with `ixj = i ^ j` and only acts when `ixj > i`. It loads both (key, idx)
// pairs, decides whether `a` sorts after `b` using this order:
//   1. a pad slot (idx == 2147483647) sorts after a non-pad slot,
//   2. NaN sorts after non-NaN,
//   3. otherwise by value, with ties (and NaN vs NaN) broken by the smaller idx,
// and swaps the two slots when that result equals the block direction
// (ascending iff `(i & k) == 0`).
macro_rules! unique_bitonic_ptx {
($name:literal, $tyld:literal, $kbytes:literal) => {
concat!(
".version 7.0\n",
".target sm_52\n",
".address_size 64\n",
"\n",
".visible .entry ",
$name,
"(\n",
" .param .u64 key_ptr,\n",
" .param .u64 idx_ptr,\n",
" .param .u32 npad,\n",
" .param .u32 j,\n",
" .param .u32 k\n",
") {\n",
" .reg .u32 %tid_r, %bid_r, %bdim_r, %i, %ixj, %npad_r, %j_r, %k_r, %t1, %t2, %g;\n",
" .reg .u32 %ia, %ib;\n",
" .reg .u64 %key_p, %idx_p, %koffa, %koffb, %ioffa, %ioffb, %addr;\n",
" .reg .",
$tyld,
" %a, %b;\n",
" .reg .pred %p_oob, %p_partner, %p_asc, %p_pada, %p_padb;\n",
" .reg .pred %p_nana, %p_nanb, %p_gt, %p_swap, %p_tmp;\n",
"\n",
" ld.param.u64 %key_p, [key_ptr];\n",
" ld.param.u64 %idx_p, [idx_ptr];\n",
" ld.param.u32 %npad_r, [npad];\n",
" ld.param.u32 %j_r, [j];\n",
" ld.param.u32 %k_r, [k];\n",
"\n",
" mov.u32 %tid_r, %tid.x;\n",
" mov.u32 %bid_r, %ctaid.x;\n",
" mov.u32 %bdim_r, %ntid.x;\n",
" mad.lo.u32 %i, %bid_r, %bdim_r, %tid_r;\n",
"\n",
" setp.ge.u32 %p_oob, %i, %npad_r;\n",
" @%p_oob bra DONE;\n",
"\n",
" // ixj = i ^ j; only the LOWER thread of the pair acts (ixj > i).\n",
" xor.b32 %ixj, %i, %j_r;\n",
" setp.le.u32 %p_partner, %ixj, %i;\n",
" @%p_partner bra DONE;\n",
"\n",
" // ascending block iff (i & k) == 0\n",
" and.b32 %t1, %i, %k_r;\n",
" setp.eq.u32 %p_asc, %t1, 0;\n",
"\n",
" // load (key,idx) at i and ixj\n",
" cvt.u64.u32 %koffa, %i;\n",
" shl.b64 %koffa, %koffa, ",
$kbytes,
";\n",
" cvt.u64.u32 %koffb, %ixj;\n",
" shl.b64 %koffb, %koffb, ",
$kbytes,
";\n",
" cvt.u64.u32 %ioffa, %i;\n",
" shl.b64 %ioffa, %ioffa, 2;\n",
" cvt.u64.u32 %ioffb, %ixj;\n",
" shl.b64 %ioffb, %ioffb, 2;\n",
" add.u64 %addr, %key_p, %koffa;\n",
" ld.global.",
$tyld,
" %a, [%addr];\n",
" add.u64 %addr, %key_p, %koffb;\n",
" ld.global.",
$tyld,
" %b, [%addr];\n",
" add.u64 %addr, %idx_p, %ioffa;\n",
" ld.global.u32 %ia, [%addr];\n",
" add.u64 %addr, %idx_p, %ioffb;\n",
" ld.global.u32 %ib, [%addr];\n",
"\n",
" // --- total-order greater(a,b): does a belong AFTER b? ---\n",
" // Computed as a u32 flag %g (0/1); predicate arithmetic is avoided (PTX\n",
" // has no setp/mov on .pred operands) by materialising each sub-predicate\n",
" // to a u32 via selp and branching on setp.*.u32.\n",
" // pad status (idx == i32::MAX), as u32 flags.\n",
" setp.eq.u32 %p_pada, %ia, 2147483647;\n",
" selp.u32 %t1, 1, 0, %p_pada;\n",
" setp.eq.u32 %p_padb, %ib, 2147483647;\n",
" selp.u32 %t2, 1, 0, %p_padb;\n",
" // if pad_a != pad_b -> greater = pad_a; resolve at the tail.\n",
" setp.ne.u32 %p_tmp, %t1, %t2;\n",
" @%p_tmp bra PAD_DECIDE;\n",
"\n",
" // same pad status: NaN-aware value compare. setp.neu.<f> (UNORDERED\n",
" // not-equal) self-compare is true ONLY for NaN (a != a iff a is NaN); the\n",
" // ORDERED setp.ne returns FALSE for NaN and would mis-rank it. Materialise\n",
" // each to a u32 flag.\n",
" setp.neu.",
$tyld,
" %p_nana, %a, %a;\n",
" selp.u32 %t1, 1, 0, %p_nana;\n",
" setp.neu.",
$tyld,
" %p_nanb, %b, %b;\n",
" selp.u32 %t2, 1, 0, %p_nanb;\n",
" setp.ne.u32 %p_tmp, %t1, %t2;\n",
" @%p_tmp bra NAN_DECIDE;\n",
" // equal NaN status. If both NaN they are 'equal' as values -> break the tie\n",
" // by ASCENDING original index (greater = ia > ib) so distinct NaN entries\n",
" // sort by original position, matching torch's radix-stable NaN order\n",
" // (verified live: unique([nan,1,nan,2,nan]).inverse = [2,0,3,1,4]).\n",
" @%p_nana bra IDX_DECIDE; // both NaN -> tie-break by index\n",
" // both finite: ordered compare; on an exact value tie, also break by index\n",
" // so equal-value runs are stable (uid is identical either way, but this\n",
" // keeps the sorted permutation deterministic).\n",
" setp.gt.",
$tyld,
" %p_gt, %a, %b;\n",
" @%p_gt bra SET_TRUE;\n",
" setp.lt.",
$tyld,
" %p_gt, %a, %b;\n",
" @%p_gt bra SET_FALSE;\n",
" // a == b (finite): tie-break by index.\n",
" bra IDX_DECIDE;\n",
"\n",
"PAD_DECIDE:\n",
" selp.u32 %g, 1, 0, %p_pada; // the pad is the greater one\n",
" bra HAVE_GREATER;\n",
"\n",
"NAN_DECIDE:\n",
" selp.u32 %g, 1, 0, %p_nana; // the NaN is the greater one\n",
" bra HAVE_GREATER;\n",
"\n",
"IDX_DECIDE:\n",
" setp.gt.u32 %p_gt, %ia, %ib; // greater iff higher original index\n",
" selp.u32 %g, 1, 0, %p_gt;\n",
" bra HAVE_GREATER;\n",
"\n",
"SET_TRUE:\n",
" mov.u32 %g, 1;\n",
" bra HAVE_GREATER;\n",
"\n",
"SET_FALSE:\n",
" mov.u32 %g, 0;\n",
"\n",
"HAVE_GREATER:\n",
" // ascending flag as u32; swap iff greater == ascending.\n",
" selp.u32 %t1, 1, 0, %p_asc;\n",
" setp.ne.u32 %p_swap, %g, %t1; // p_swap true when they DIFFER\n",
" @%p_swap bra DONE; // differ -> no swap\n",
"\n",
" // exchange key and idx between i and ixj\n",
" add.u64 %addr, %key_p, %koffa;\n",
" st.global.",
$tyld,
" [%addr], %b;\n",
" add.u64 %addr, %key_p, %koffb;\n",
" st.global.",
$tyld,
" [%addr], %a;\n",
" add.u64 %addr, %idx_p, %ioffa;\n",
" st.global.u32 [%addr], %ib;\n",
" add.u64 %addr, %idx_p, %ioffb;\n",
" st.global.u32 [%addr], %ia;\n",
"\n",
"DONE:\n",
" ret;\n",
"}\n"
)
};
}
/// Bitonic compare-exchange kernel for `f32` keys (key offset shift 2, i.e. 4-byte stride).
const UNIQUE_BITONIC_F32_PTX: &str = unique_bitonic_ptx!("unique_bitonic_f32_kernel", "f32", "2");
/// Bitonic compare-exchange kernel for `f64` keys (key offset shift 3, i.e. 8-byte stride).
const UNIQUE_BITONIC_F64_PTX: &str = unique_bitonic_ptx!("unique_bitonic_f64_kernel", "f64", "3");
#[cfg(feature = "cuda")]
/// Returns the smallest power of two that is `>= n`; `0` and `1` both map to `1`.
///
/// # Panics
/// Panics when the result does not fit in `usize`
/// (`n > 1 << (usize::BITS - 1)`). The previous shift loop never terminated in
/// that case because the accumulator wrapped to zero.
fn next_pow2(n: usize) -> usize {
    // The standard library already treats 0 as 1 and reports overflow as `None`.
    n.checked_next_power_of_two()
        .expect("next_pow2: next power of two does not fit in usize")
}
/// Sorts `key` in place on the device and applies the same swaps to `idx`.
///
/// Issues the classic bitonic schedule: for `k = 2, 4, …, npad` and
/// `j = k/2, k/4, …, 1`, one launch of `bitonic_kernel` with the arguments
/// `(key, idx, npad, j, k)`. All launches go to the device's stream.
///
/// `npad` must be a power of two: the kernel pairs slot `i` with `i ^ j`,
/// which stays below `npad` only in that case. `npad <= 1` is a no-op.
///
/// Both slices are handed to the launch builder as `&mut` because the kernel
/// writes to them; a shared borrow would register the access as read-only.
///
/// # Errors
/// Returns [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained and
/// propagates the first failed launch.
#[cfg(feature = "cuda")]
fn bitonic_sort_by_key(
    key: &mut CudaSlice<impl cudarc::driver::DeviceRepr>,
    idx: &mut CudaSlice<i32>,
    npad: usize,
    device: &GpuDevice,
    bitonic_ptx: &'static str,
    bitonic_kernel: &'static str,
) -> GpuResult<()> {
    if npad <= 1 {
        return Ok(());
    }
    debug_assert!(
        npad.is_power_of_two(),
        "bitonic_sort_by_key: npad must be a power of two"
    );
    let stream = device.stream();
    let ctx = device.context();
    let f = get_or_compile(ctx, bitonic_ptx, bitonic_kernel, device.ordinal() as u32).map_err(
        |e| GpuError::PtxCompileFailed {
            kernel: bitonic_kernel,
            source: e,
        },
    )?;
    let block: u32 = 256;
    let grid = (npad as u32).div_ceil(block).max(1);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let npad_u = npad as u32;
    // k: size of the bitonic sequences being merged; j: partner distance.
    let mut k = 2usize;
    while k <= npad {
        let mut j = k >> 1;
        while j >= 1 {
            let j_u = j as u32;
            let k_u = k as u32;
            // SAFETY: assumes `key` and `idx` each hold at least `npad` elements
            // (TODO confirm at call sites); with `npad` a power of two the kernel
            // only touches slots `i` and `i ^ j`, both below `npad`.
            unsafe {
                stream
                    .launch_builder(&f)
                    .arg(&mut *key)
                    .arg(&mut *idx)
                    .arg(&npad_u)
                    .arg(&j_u)
                    .arg(&k_u)
                    .launch(cfg)?;
            }
            j >>= 1;
        }
        k <<= 1;
    }
    Ok(())
}
/// Launches `unique_init_f32_kernel` to fill the sort buffers.
///
/// For every slot `i < npad` the kernel writes `key[i] = input[i]`, `idx[i] = i`
/// when `i < n`, and the padding pair `(+INF, i32::MAX)` otherwise.
///
/// Both `key` and `idx` are written by the kernel, so both are pushed as
/// mutable arguments.
///
/// # Errors
/// Returns [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained and
/// propagates a failed launch.
#[cfg(feature = "cuda")]
fn launch_unique_init_f32(
    input: &CudaBuffer<f32>,
    key: &mut CudaBuffer<f32>,
    idx: &mut CudaSlice<i32>,
    n: usize,
    npad: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    let stream = device.stream();
    let ctx = device.context();
    let f = get_or_compile(
        ctx,
        UNIQUE_INIT_F32_PTX,
        "unique_init_f32_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "unique_init_f32_kernel",
        source: e,
    })?;
    let block: u32 = 256;
    // One thread per padded slot.
    let grid = (npad as u32).div_ceil(block).max(1);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let n_u = n as u32;
    let npad_u = npad as u32;
    // SAFETY: assumes `input` holds at least `n` elements and `key`/`idx` at
    // least `npad` — TODO confirm at call sites. The kernel reads `input[i]`
    // only for `i < n` and writes `key[i]`/`idx[i]` only for `i < npad`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(key.inner_mut())
            .arg(&mut *idx)
            .arg(&n_u)
            .arg(&npad_u)
            .launch(cfg)?;
    }
    Ok(())
}
/// Launches `unique_init_f64_kernel` to fill the sort buffers.
///
/// For every slot `i < npad` the kernel writes `key[i] = input[i]`, `idx[i] = i`
/// when `i < n`, and the padding pair `(+INF, i32::MAX)` otherwise.
///
/// Both `key` and `idx` are written by the kernel, so both are pushed as
/// mutable arguments.
///
/// # Errors
/// Returns [`GpuError::PtxCompileFailed`] if the kernel cannot be obtained and
/// propagates a failed launch.
#[cfg(feature = "cuda")]
fn launch_unique_init_f64(
    input: &CudaBuffer<f64>,
    key: &mut CudaBuffer<f64>,
    idx: &mut CudaSlice<i32>,
    n: usize,
    npad: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    let stream = device.stream();
    let ctx = device.context();
    let f = get_or_compile(
        ctx,
        UNIQUE_INIT_F64_PTX,
        "unique_init_f64_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "unique_init_f64_kernel",
        source: e,
    })?;
    let block: u32 = 256;
    // One thread per padded slot.
    let grid = (npad as u32).div_ceil(block).max(1);
    let cfg = LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (block, 1, 1),
        shared_mem_bytes: 0,
    };
    let n_u = n as u32;
    let npad_u = npad as u32;
    // SAFETY: assumes `input` holds at least `n` elements and `key`/`idx` at
    // least `npad` — TODO confirm at call sites. The kernel reads `input[i]`
    // only for `i < n` and writes `key[i]`/`idx[i]` only for `i < npad`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input.inner())
            .arg(key.inner_mut())
            .arg(&mut *idx)
            .arg(&n_u)
            .arg(&npad_u)
            .launch(cfg)?;
    }
    Ok(())
}
/// Turns the scan over the *sorted* keys into `(inverse, counts, out_len)`.
///
/// `incl_host[i]` is the 1-based unique-id of the `i`-th sorted element and
/// `sorted_idx[i]` is that element's position in the original input. The
/// result holds:
/// * `inverse[orig]` – the 0-based unique-id of the element originally at `orig`,
/// * `counts[uid]` – how many elements share unique-id `uid`,
/// * `out_len` – the last scan value, i.e. the number of unique values.
///
/// An empty scan yields two empty vectors and `0`.
fn decode_unique(incl_host: &[f32], sorted_idx: &[i32]) -> (Vec<usize>, Vec<usize>, usize) {
    let Some(&last) = incl_host.last() else {
        return (Vec::new(), Vec::new(), 0);
    };
    let out_len = last as usize;
    let mut inverse = vec![0usize; incl_host.len()];
    let mut counts = vec![0usize; out_len];
    for (sorted_pos, &rank) in incl_host.iter().enumerate() {
        // Scan values are 1-based; shift to a 0-based unique-id.
        let uid = (rank as usize) - 1;
        let original_pos = sorted_idx[sorted_pos] as usize;
        inverse[original_pos] = uid;
        counts[uid] += 1;
    }
    (inverse, counts, out_len)
}
/// Computes the sorted unique values of the first `n` elements of `input`.
///
/// Steps: copy the input into a key buffer padded to a power of two (pads are
/// `+INF` with index `i32::MAX`), sort keys and original indices with the
/// bitonic kernel (NaNs after all numbers, pads last, ties by original index),
/// flag and scan the runs of the first `n` sorted keys, then compact.
///
/// Returns `(out, inverse, counts)`:
/// * `out` – device buffer with one slot per unique value,
/// * `inverse[i]` – index into `out` for the element originally at position `i`,
/// * `counts[u]` – number of input elements equal to unique value `u`.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] when `n > i32::MAX`, or when `n > 2^24` and
///   the number of unique values reaches 2^24: the scan is carried in `f32`,
///   which cannot hold larger counts exactly.
/// * Any allocation, PTX compilation, launch, scan or readback failure.
#[cfg(feature = "cuda")]
pub fn gpu_unique_f32(
    input: &CudaBuffer<f32>,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f32>, Vec<usize>, Vec<usize>)> {
    // `f32` has a 24-bit significand: every integer up to 2^24 is exact, odd
    // integers above it are not representable.
    const F32_EXACT_INT_MAX: usize = 1 << f32::MANTISSA_DIGITS;

    if n == 0 {
        return Ok((alloc_zeros_f32(0, device)?, vec![], vec![]));
    }
    // Original positions are stored as i32, with i32::MAX reserved for pads.
    if n > (i32::MAX as usize) {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: i32::MAX as usize,
        });
    }
    let npad = next_pow2(n);
    let mut key = alloc_zeros_f32(npad, device)?;
    let mut idx: CudaSlice<i32> = device.stream().alloc_zeros::<i32>(npad)?;
    launch_unique_init_f32(input, &mut key, &mut idx, n, npad, device)?;
    bitonic_sort_by_key(
        key.inner_mut(),
        &mut idx,
        npad,
        device,
        UNIQUE_BITONIC_F32_PTX,
        "unique_bitonic_f32_kernel",
    )?;
    // Only the first `n` sorted slots are real data; pads sort to the tail.
    let incl = run_flags_and_scan(
        key.inner(),
        n,
        device,
        RUN_FLAG_F32_PTX,
        "run_flag_f32_kernel",
    )?;
    let incl_host = gpu_to_cpu(&incl, device)?;

    // With n <= 2^24 the scan is exact. For larger n a total that has reached
    // 2^24 may already have lost increments, so refuse instead of returning
    // corrupted indices.
    let unique_total = incl_host.last().copied().unwrap_or(0.0);
    if n > F32_EXACT_INT_MAX && unique_total >= F32_EXACT_INT_MAX as f32 {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: F32_EXACT_INT_MAX,
        });
    }

    let mut sorted_idx_host = device.stream().clone_dtoh(&idx)?;
    sorted_idx_host.truncate(n);
    let (inverse, counts, out_len) = decode_unique(&incl_host, &sorted_idx_host);
    let mut out = alloc_zeros_f32(out_len, device)?;
    launch_compact_f32(&key, &incl, &mut out, n, device)?;
    Ok((out, inverse, counts))
}
/// Computes the sorted unique values of the first `n` elements of `input`.
///
/// Steps: copy the input into a key buffer padded to a power of two (pads are
/// `+INF` with index `i32::MAX`), sort keys and original indices with the
/// bitonic kernel (NaNs after all numbers, pads last, ties by original index),
/// flag and scan the runs of the first `n` sorted keys, then compact.
///
/// Returns `(out, inverse, counts)`:
/// * `out` – device buffer with one slot per unique value,
/// * `inverse[i]` – index into `out` for the element originally at position `i`,
/// * `counts[u]` – number of input elements equal to unique value `u`.
///
/// # Errors
/// * [`GpuError::LengthMismatch`] when `n > i32::MAX`, or when `n > 2^24` and
///   the number of unique values reaches 2^24: the scan is carried in `f32`
///   (even for `f64` data), which cannot hold larger counts exactly.
/// * Any allocation, PTX compilation, launch, scan or readback failure.
#[cfg(feature = "cuda")]
pub fn gpu_unique_f64(
    input: &CudaBuffer<f64>,
    n: usize,
    device: &GpuDevice,
) -> GpuResult<(CudaBuffer<f64>, Vec<usize>, Vec<usize>)> {
    // `f32` has a 24-bit significand: every integer up to 2^24 is exact, odd
    // integers above it are not representable.
    const F32_EXACT_INT_MAX: usize = 1 << f32::MANTISSA_DIGITS;

    if n == 0 {
        return Ok((alloc_zeros_f64(0, device)?, vec![], vec![]));
    }
    // Original positions are stored as i32, with i32::MAX reserved for pads.
    if n > (i32::MAX as usize) {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: i32::MAX as usize,
        });
    }
    let npad = next_pow2(n);
    let mut key = alloc_zeros_f64(npad, device)?;
    let mut idx: CudaSlice<i32> = device.stream().alloc_zeros::<i32>(npad)?;
    launch_unique_init_f64(input, &mut key, &mut idx, n, npad, device)?;
    bitonic_sort_by_key(
        key.inner_mut(),
        &mut idx,
        npad,
        device,
        UNIQUE_BITONIC_F64_PTX,
        "unique_bitonic_f64_kernel",
    )?;
    // Only the first `n` sorted slots are real data; pads sort to the tail.
    let incl = run_flags_and_scan(
        key.inner(),
        n,
        device,
        RUN_FLAG_F64_PTX,
        "run_flag_f64_kernel",
    )?;
    let incl_host = gpu_to_cpu(&incl, device)?;

    // With n <= 2^24 the scan is exact. For larger n a total that has reached
    // 2^24 may already have lost increments, so refuse instead of returning
    // corrupted indices.
    let unique_total = incl_host.last().copied().unwrap_or(0.0);
    if n > F32_EXACT_INT_MAX && unique_total >= F32_EXACT_INT_MAX as f32 {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: F32_EXACT_INT_MAX,
        });
    }

    let mut sorted_idx_host = device.stream().clone_dtoh(&idx)?;
    sorted_idx_host.truncate(n);
    let (inverse, counts, out_len) = decode_unique(&incl_host, &sorted_idx_host);
    let mut out = alloc_zeros_f64(out_len, device)?;
    launch_compact_f64(&key, &incl, &mut out, n, device)?;
    Ok((out, inverse, counts))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transfer::cpu_to_gpu;
/// Copies an `i64` device slice to the host, keeping exactly `slice.len()` elements.
/// Panics if the device-to-host copy fails.
fn read_i64(slice: &CudaSlice<i64>, device: &GpuDevice) -> Vec<i64> {
    let expected_len = slice.len();
    let mut host = device.stream().clone_dtoh(slice).unwrap();
    host.truncate(expected_len);
    host
}
/// CPU reference for searchsorted over sorted `bounds`.
///
/// For every value returns the number of boundaries that are `<= v` when
/// `right` is set, or `< v` otherwise.
fn cpu_searchsorted_ref(bounds: &[f64], vals: &[f64], right: bool) -> Vec<i64> {
    let mut positions = Vec::with_capacity(vals.len());
    for &needle in vals {
        let pos = match right {
            true => bounds.partition_point(|&edge| edge <= needle),
            false => bounds.partition_point(|&edge| edge < needle),
        };
        positions.push(pos as i64);
    }
    positions
}
/// `gpu_searchsorted_f32` must agree with the CPU reference for both values of
/// `right`. Returns without asserting when device 0 cannot be opened.
#[test]
fn searchsorted_f32_left_and_right_match_cpu() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let boundaries = [1.0f32, 3.0, 5.0, 7.0];
    let queries = [0.0f32, 2.0, 3.0, 6.0, 8.0];
    let boundaries_gpu = cpu_to_gpu(&boundaries, &device).unwrap();
    let queries_gpu = cpu_to_gpu(&queries, &device).unwrap();

    // right = true
    let right_out = gpu_searchsorted_f32(
        queries_gpu.inner(),
        boundaries_gpu.inner(),
        queries.len(),
        boundaries.len(),
        true,
        &device,
    )
    .unwrap();
    assert_eq!(right_out.len(), queries.len());
    let right_host = read_i64(&right_out, &device);
    let boundaries_f64: Vec<f64> = boundaries.iter().map(|&b| b as f64).collect();
    let queries_f64: Vec<f64> = queries.iter().map(|&q| q as f64).collect();
    assert_eq!(
        right_host,
        cpu_searchsorted_ref(&boundaries_f64, &queries_f64, true)
    );
    assert_eq!(right_host, vec![0, 1, 2, 3, 4]);

    // right = false
    let left_out = gpu_searchsorted_f32(
        queries_gpu.inner(),
        boundaries_gpu.inner(),
        queries.len(),
        boundaries.len(),
        false,
        &device,
    )
    .unwrap();
    let left_host = read_i64(&left_out, &device);
    assert_eq!(
        left_host,
        cpu_searchsorted_ref(&boundaries_f64, &queries_f64, false)
    );
}
/// Values that coincide with a boundary land before it with `right = false`
/// and after it with `right = true`. Returns early without a device.
#[test]
fn searchsorted_f32_boundary_tie_left_vs_right() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let boundaries = [1.0f32, 3.0, 5.0, 7.0];
    let queries = [1.0f32, 3.0, 5.0, 7.0];
    let boundaries_gpu = cpu_to_gpu(&boundaries, &device).unwrap();
    let queries_gpu = cpu_to_gpu(&queries, &device).unwrap();

    let left_out = gpu_searchsorted_f32(
        queries_gpu.inner(),
        boundaries_gpu.inner(),
        queries.len(),
        boundaries.len(),
        false,
        &device,
    )
    .unwrap();
    assert_eq!(read_i64(&left_out, &device), vec![0, 1, 2, 3]);

    let right_out = gpu_searchsorted_f32(
        queries_gpu.inner(),
        boundaries_gpu.inner(),
        queries.len(),
        boundaries.len(),
        true,
        &device,
    )
    .unwrap();
    assert_eq!(read_i64(&right_out, &device), vec![1, 2, 3, 4]);
}
/// With no boundaries every value maps to position 0. Returns early without a device.
#[test]
fn searchsorted_f32_empty_boundaries_all_zero() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let no_boundaries: [f32; 0] = [];
    let queries = [1.0f32, 2.0];
    let boundaries_gpu = cpu_to_gpu(&no_boundaries, &device).unwrap();
    let queries_gpu = cpu_to_gpu(&queries, &device).unwrap();
    let positions = gpu_searchsorted_f32(
        queries_gpu.inner(),
        boundaries_gpu.inner(),
        queries.len(),
        0,
        true,
        &device,
    )
    .unwrap();
    assert_eq!(read_i64(&positions, &device), vec![0, 0]);
}
/// `gpu_searchsorted_f64` matches the CPU reference for `right = false` and
/// `right = true`, including a duplicated boundary. Returns early without a device.
#[test]
fn searchsorted_f64_matches_cpu() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let boundaries = [-2.5f64, 0.0, 0.0, 4.25, 9.0];
    let queries = [-3.0f64, -2.5, 0.0, 1.0, 9.0, 100.0];
    let boundaries_gpu = cpu_to_gpu(&boundaries, &device).unwrap();
    let queries_gpu = cpu_to_gpu(&queries, &device).unwrap();
    for side in [false, true] {
        let positions = gpu_searchsorted_f64(
            queries_gpu.inner(),
            boundaries_gpu.inner(),
            queries.len(),
            boundaries.len(),
            side,
            &device,
        )
        .unwrap();
        let host = read_i64(&positions, &device);
        assert_eq!(host, cpu_searchsorted_ref(&boundaries, &queries, side));
    }
}
/// Copies an `f32` device slice to the host, keeping exactly `slice.len()` elements.
/// Panics if the device-to-host copy fails.
fn read_f32(slice: &CudaSlice<f32>, device: &GpuDevice) -> Vec<f32> {
    let expected_len = slice.len();
    let mut host = device.stream().clone_dtoh(slice).unwrap();
    host.truncate(expected_len);
    host
}
/// Copies an `f64` device slice to the host, keeping exactly `slice.len()` elements.
/// Panics if the device-to-host copy fails.
fn read_f64(slice: &CudaSlice<f64>, device: &GpuDevice) -> Vec<f64> {
    let expected_len = slice.len();
    let mut host = device.stream().clone_dtoh(slice).unwrap();
    host.truncate(expected_len);
    host
}
/// CPU reference for row-wise top-k.
///
/// `data` is read as `outer` rows of `dim` values. Each row is stably sorted
/// (descending when `largest`, ascending otherwise) and the first `k` values
/// and their in-row indices are appended to the outputs. Equal values keep
/// ascending index order. Panics if a row contains NaN.
fn cpu_topk_ref(
    data: &[f64],
    outer: usize,
    dim: usize,
    k: usize,
    largest: bool,
) -> (Vec<f64>, Vec<i64>) {
    let mut top_vals = Vec::with_capacity(outer * k);
    let mut top_idxs = Vec::with_capacity(outer * k);
    for row_no in 0..outer {
        let start = row_no * dim;
        let row = &data[start..start + dim];
        let mut order: Vec<usize> = (0..dim).collect();
        order.sort_by(|&lhs, &rhs| {
            let ascending = row[lhs].partial_cmp(&row[rhs]).unwrap();
            if largest {
                ascending.reverse()
            } else {
                ascending
            }
        });
        let winners = &order[..k];
        top_vals.extend(winners.iter().map(|&pos| row[pos]));
        top_idxs.extend(winners.iter().map(|&pos| pos as i64));
    }
    (top_vals, top_idxs)
}
/// Top-3 largest of a single row: checked against literal values and against
/// the CPU reference. Returns early without a device.
#[test]
fn topk_f32_largest_matches_cpu_ref() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let row = [3.0f32, 1.0, 4.0, 1.0, 5.0, 9.0];
    let row_gpu = cpu_to_gpu(&row, &device).unwrap();
    let (vals_gpu, idx_gpu) = gpu_topk_f32(row_gpu.inner(), 1, 6, 3, true, &device).unwrap();
    assert_eq!(vals_gpu.len(), 3);
    assert_eq!(idx_gpu.len(), 3);

    let got_vals = read_f32(&vals_gpu, &device);
    let got_idx = read_i64(&idx_gpu, &device);
    assert_eq!(got_vals, vec![9.0, 5.0, 4.0]);
    assert_eq!(got_idx, vec![5, 4, 2]);

    let row_f64: Vec<f64> = row.iter().map(|&x| x as f64).collect();
    let (want_vals, want_idx) = cpu_topk_ref(&row_f64, 1, 6, 3, true);
    let got_vals_f64: Vec<f64> = got_vals.iter().map(|&x| x as f64).collect();
    assert_eq!(got_vals_f64, want_vals);
    assert_eq!(got_idx, want_idx);
}
/// Top-2 smallest of a single row with a duplicated minimum. Returns early
/// without a device.
#[test]
fn topk_f32_smallest_matches_cpu_ref() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let row = [3.0f32, 1.0, 4.0, 1.0, 5.0];
    let row_gpu = cpu_to_gpu(&row, &device).unwrap();
    let (vals_gpu, idx_gpu) = gpu_topk_f32(row_gpu.inner(), 1, 5, 2, false, &device).unwrap();
    let got_vals = read_f32(&vals_gpu, &device);
    let got_idx = read_i64(&idx_gpu, &device);
    assert_eq!(got_vals, vec![1.0, 1.0]);
    assert_eq!(got_idx, vec![1, 3]);
}
/// Equal values must be reported in ascending index order, matching the
/// stable CPU reference. Returns early without a device.
#[test]
fn topk_f32_ties_ascending_index() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let row = [2.0f32, 2.0, 2.0, 2.0, 1.0];
    let row_gpu = cpu_to_gpu(&row, &device).unwrap();
    let (vals_gpu, idx_gpu) = gpu_topk_f32(row_gpu.inner(), 1, 5, 3, true, &device).unwrap();
    let got_vals = read_f32(&vals_gpu, &device);
    let got_idx = read_i64(&idx_gpu, &device);
    assert_eq!(got_vals, vec![2.0, 2.0, 2.0]);
    assert_eq!(got_idx, vec![0, 1, 2]);

    let row_f64: Vec<f64> = row.iter().map(|&x| x as f64).collect();
    let (_, want_idx) = cpu_topk_ref(&row_f64, 1, 5, 3, true);
    assert_eq!(got_idx, want_idx);
}
/// Two rows of four values, top-2 largest per row; indices are row-local.
/// Returns early without a device.
#[test]
fn topk_f32_multi_row() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let rows = [1.0f32, 5.0, 3.0, 2.0, 8.0, 0.0, 7.0, 6.0];
    let rows_gpu = cpu_to_gpu(&rows, &device).unwrap();
    let (vals_gpu, idx_gpu) = gpu_topk_f32(rows_gpu.inner(), 2, 4, 2, true, &device).unwrap();
    let got_vals = read_f32(&vals_gpu, &device);
    let got_idx = read_i64(&idx_gpu, &device);
    assert_eq!(got_vals, vec![5.0, 3.0, 8.0, 7.0]);
    assert_eq!(got_idx, vec![1, 2, 0, 2]);

    let rows_f64: Vec<f64> = rows.iter().map(|&x| x as f64).collect();
    let (want_vals, want_idx) = cpu_topk_ref(&rows_f64, 2, 4, 2, true);
    let got_vals_f64: Vec<f64> = got_vals.iter().map(|&x| x as f64).collect();
    assert_eq!(got_vals_f64, want_vals);
    assert_eq!(got_idx, want_idx);
}
/// With `k == dim` the result is the whole row sorted descending. Returns
/// early without a device.
#[test]
fn topk_f32_k_equals_dim() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let row = [3.0f32, 1.0, 2.0];
    let row_gpu = cpu_to_gpu(&row, &device).unwrap();
    let (vals_gpu, idx_gpu) = gpu_topk_f32(row_gpu.inner(), 1, 3, 3, true, &device).unwrap();
    let got_vals = read_f32(&vals_gpu, &device);
    let got_idx = read_i64(&idx_gpu, &device);
    assert_eq!(got_vals, vec![3.0, 2.0, 1.0]);
    assert_eq!(got_idx, vec![0, 2, 1]);
}
/// `gpu_topk_f64` matches the CPU reference for both `largest` settings.
/// Returns early without a device.
#[test]
fn topk_f64_matches_cpu_ref() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let row = [3.0f64, 1.0, 4.0, 1.5, 5.0, 9.0, 2.0, 6.0];
    let row_gpu = cpu_to_gpu(&row, &device).unwrap();
    for want_largest in [true, false] {
        let (vals_gpu, idx_gpu) =
            gpu_topk_f64(row_gpu.inner(), 1, 8, 4, want_largest, &device).unwrap();
        let got_vals = read_f64(&vals_gpu, &device);
        let got_idx = read_i64(&idx_gpu, &device);
        let (want_vals, want_idx) = cpu_topk_ref(&row, 1, 8, 4, want_largest);
        assert_eq!(got_vals, want_vals);
        assert_eq!(got_idx, want_idx);
    }
}
/// CPU reference histogram with `bins` equal-width bins over `[min, max]`.
///
/// Values outside the closed range, and NaN (which fails both comparisons),
/// are ignored. A value equal to `max` is counted in the last bin.
fn cpu_histc_ref(data: &[f64], bins: usize, min: f64, max: f64) -> Vec<f64> {
    let span = max - min;
    let mut histogram = vec![0.0f64; bins];
    let in_range = |x: f64| x >= min && x <= max;
    for x in data.iter().copied().filter(|&x| in_range(x)) {
        let raw = ((x - min) * bins as f64 / span) as i64;
        // `x == max` computes bin index `bins`; fold it into the last bin.
        let slot = if raw == bins as i64 { raw - 1 } else { raw };
        histogram[slot as usize] += 1.0;
    }
    histogram
}
/// Eleven integers 0..=10 in five bins over [0, 10]; the value equal to `max`
/// lands in the last bin. Returns early without a device.
#[test]
fn histc_f32_matches_torch_bins() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let samples = [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
    let samples_gpu = cpu_to_gpu(&samples, &device).unwrap();
    let hist_gpu =
        gpu_histc_f32(samples_gpu.inner(), samples.len(), 5, 0.0, 10.0, &device).unwrap();
    assert_eq!(hist_gpu.len(), 5);
    let hist = read_f32(&hist_gpu, &device);
    assert_eq!(hist, vec![2.0, 2.0, 2.0, 2.0, 3.0]);

    let samples_f64: Vec<f64> = samples.iter().map(|&x| x as f64).collect();
    let expected = cpu_histc_ref(&samples_f64, 5, 0.0, 10.0);
    let hist_f64: Vec<f64> = hist.iter().map(|&x| x as f64).collect();
    assert_eq!(hist_f64, expected);
}
/// Values below `min`, above `max`, and NaN must not be counted. Returns
/// early without a device.
#[test]
fn histc_f32_skips_out_of_range_and_nan() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let samples = [-1.0f32, 0.5, 1.5, 2.5, 4.0, 5.0, f32::NAN];
    let samples_gpu = cpu_to_gpu(&samples, &device).unwrap();
    let hist_gpu =
        gpu_histc_f32(samples_gpu.inner(), samples.len(), 4, 0.0, 4.0, &device).unwrap();
    let hist = read_f32(&hist_gpu, &device);
    assert_eq!(hist, vec![1.0, 1.0, 1.0, 1.0]);
}
/// Quarter steps over [0, 1] in four bins; checked against literal counts and
/// the CPU reference. Returns early without a device.
#[test]
fn histc_f64_matches_torch_bins() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let samples = [0.0f64, 0.25, 0.5, 0.75, 1.0];
    let samples_gpu = cpu_to_gpu(&samples, &device).unwrap();
    let hist_gpu =
        gpu_histc_f64(samples_gpu.inner(), samples.len(), 4, 0.0, 1.0, &device).unwrap();
    let hist = read_f64(&hist_gpu, &device);
    assert_eq!(hist, vec![1.0, 1.0, 1.0, 2.0]);
    assert_eq!(hist, cpu_histc_ref(&samples, 4, 0.0, 1.0));
}
/// CPU reference for one output grid of an `ij`-indexed meshgrid.
///
/// Produces `shapes.iter().product()` values in row-major order; the value at
/// flat position `p` is `vec[(p / inner) % shapes[axis]]`, where `inner` is
/// the product of the dimensions after `axis` (at least 1).
fn cpu_meshgrid_axis(vec: &[f64], shapes: &[usize], axis: usize) -> Vec<f64> {
    let total = shapes.iter().product::<usize>();
    let stride = shapes[axis + 1..].iter().product::<usize>().max(1);
    let mut grid = Vec::with_capacity(total);
    for flat in 0..total {
        grid.push(vec[(flat / stride) % shapes[axis]]);
    }
    grid
}
/// Both output grids of a 3x2 `ij` meshgrid, checked against literals and the
/// CPU reference. Returns early without a device.
#[test]
fn meshgrid_f32_two_axis_ij() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let first = [1.0f32, 2.0, 3.0];
    let second = [4.0f32, 5.0];
    let shapes = [3usize, 2];
    let first_gpu = cpu_to_gpu(&first, &device).unwrap();
    let second_gpu = cpu_to_gpu(&second, &device).unwrap();
    let total = 6;

    // Axis 0: each value repeated across the trailing dimension.
    let grid0_gpu = gpu_meshgrid_f32(first_gpu.inner(), total, 2, 3, &device).unwrap();
    assert_eq!(grid0_gpu.len(), total);
    let grid0 = read_f32(&grid0_gpu, &device);
    assert_eq!(grid0, vec![1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);

    // Axis 1: the vector tiled once per leading index.
    let grid1_gpu = gpu_meshgrid_f32(second_gpu.inner(), total, 1, 2, &device).unwrap();
    let grid1 = read_f32(&grid1_gpu, &device);
    assert_eq!(grid1, vec![4.0, 5.0, 4.0, 5.0, 4.0, 5.0]);

    let first_f64: Vec<f64> = first.iter().map(|&x| x as f64).collect();
    let second_f64: Vec<f64> = second.iter().map(|&x| x as f64).collect();
    let grid0_f64: Vec<f64> = grid0.iter().map(|&x| x as f64).collect();
    let grid1_f64: Vec<f64> = grid1.iter().map(|&x| x as f64).collect();
    assert_eq!(grid0_f64, cpu_meshgrid_axis(&first_f64, &shapes, 0));
    assert_eq!(grid1_f64, cpu_meshgrid_axis(&second_f64, &shapes, 1));
}
/// Axis 0 of a 2x3x2 `ij` meshgrid: each input value fills a block of six.
/// Returns early without a device.
#[test]
fn meshgrid_f64_three_axis_ij() {
    let Ok(device) = GpuDevice::new(0) else {
        return;
    };
    let axis_vals = [10.0f64, 20.0];
    let shapes = [2usize, 3, 2];
    let total = 12;
    let axis_vals_gpu = cpu_to_gpu(&axis_vals, &device).unwrap();
    let grid_gpu = gpu_meshgrid_f64(axis_vals_gpu.inner(), total, 6, 2, &device).unwrap();
    let grid = read_f64(&grid_gpu, &device);
    assert_eq!(grid, cpu_meshgrid_axis(&axis_vals, &shapes, 0));
    let expected = vec![
        10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 20.0, 20.0, 20.0, 20.0, 20.0, 20.0,
    ];
    assert_eq!(grid, expected);
}
}