#![cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits};
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::module_cache::get_or_compile;
/// Threads per block used by every one-thread-per-element kernel in this file.
const BLOCK_SIZE: u32 = 256;

/// Builds a 1-D launch configuration that assigns one thread to each of `n`
/// elements.
///
/// The grid holds `ceil(n / BLOCK_SIZE)` blocks of `BLOCK_SIZE` threads, with a
/// minimum of one block so that `n == 0` still yields a valid configuration.
///
/// The ceiling division is done in `usize` *before* narrowing to `u32`.
/// Narrowing first (`n as u32`) would wrap for `n > u32::MAX`, and adding
/// `BLOCK_SIZE - 1` in `u32` would saturate for `n` close to `u32::MAX`,
/// both of which produce a grid with too few blocks to cover the input.
/// A block count that does not fit in `u32` is clamped to `u32::MAX`.
fn launch_1d(n: usize) -> LaunchConfig {
    let threads_per_block = BLOCK_SIZE as usize;
    let full_blocks = n / threads_per_block;
    let has_partial_block = n % threads_per_block != 0;
    let blocks = full_blocks + usize::from(has_partial_block);
    // Lossless after the clamp: the value is at most `u32::MAX`.
    let grid = (blocks.min(u32::MAX as usize) as u32).max(1);
    LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// PTX source for `masked_fill_f32_kernel(in_ptr, mask_ptr, out_ptr, value, n)`.
///
/// One thread per element: thread `idx` returns immediately when `idx >= n`.
/// Otherwise it loads the mask byte `mask[idx]` and the 4-byte `f32` `in[idx]`,
/// then stores `value` to `out[idx]` if the mask byte is non-zero and the
/// loaded input element if it is zero.
const MASKED_FILL_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry masked_fill_f32_kernel(
.param .u64 in_ptr, .param .u64 mask_ptr, .param .u64 out_ptr,
.param .f32 value, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %in, %mk, %out, %ioff, %moff;
.reg .f32 %v, %iv; .reg .u16 %m; .reg .pred %p, %sel;
ld.param.u64 %in, [in_ptr]; ld.param.u64 %mk, [mask_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.f32 %v, [value]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr; @%p bra DONE;
cvt.u64.u32 %moff, %idx; add.u64 %mk, %mk, %moff;
ld.global.u8 %m, [%mk]; setp.ne.u16 %sel, %m, 0;
cvt.u64.u32 %ioff, %idx; shl.b64 %ioff, %ioff, 2;
add.u64 %in, %in, %ioff; ld.global.f32 %iv, [%in];
selp.f32 %iv, %v, %iv, %sel;
add.u64 %out, %out, %ioff; st.global.f32 [%out], %iv;
DONE: ret;
}
";
/// PTX source for `masked_fill_f64_kernel(in_ptr, mask_ptr, out_ptr, value, n)`.
///
/// Same per-element logic as the `f32` kernel, but with 8-byte `f64` elements
/// (element offset is `idx << 3`): `out[idx]` receives `value` where the mask
/// byte is non-zero and `in[idx]` elsewhere. Threads with `idx >= n` do nothing.
const MASKED_FILL_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry masked_fill_f64_kernel(
.param .u64 in_ptr, .param .u64 mask_ptr, .param .u64 out_ptr,
.param .f64 value, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %in, %mk, %out, %ioff, %moff;
.reg .f64 %v, %iv; .reg .u16 %m; .reg .pred %p, %sel;
ld.param.u64 %in, [in_ptr]; ld.param.u64 %mk, [mask_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.f64 %v, [value]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr; @%p bra DONE;
cvt.u64.u32 %moff, %idx; add.u64 %mk, %mk, %moff;
ld.global.u8 %m, [%mk]; setp.ne.u16 %sel, %m, 0;
cvt.u64.u32 %ioff, %idx; shl.b64 %ioff, %ioff, 3;
add.u64 %in, %in, %ioff; ld.global.f64 %iv, [%in];
selp.f64 %iv, %v, %iv, %sel;
add.u64 %out, %out, %ioff; st.global.f64 [%out], %iv;
DONE: ret;
}
";
/// PTX source for `masked_fill_f16_kernel(in_ptr, mask_ptr, out_ptr, value, n)`.
///
/// The fill `value` is passed as `f32` and rounded to half precision with
/// `cvt.rn.f16.f32` inside the kernel. Elements are moved as raw 16-bit words
/// (`.b16`, offset `idx << 1`): `out[idx]` receives the converted value where
/// the mask byte is non-zero and `in[idx]` elsewhere. Declares `.target sm_53`.
const MASKED_FILL_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry masked_fill_f16_kernel(
.param .u64 in_ptr, .param .u64 mask_ptr, .param .u64 out_ptr,
.param .f32 value, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %in, %mk, %out, %ioff, %moff;
.reg .f32 %v; .reg .b16 %vh, %iv; .reg .u16 %m; .reg .pred %p, %sel;
ld.param.u64 %in, [in_ptr]; ld.param.u64 %mk, [mask_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.f32 %v, [value]; ld.param.u32 %nr, [n];
cvt.rn.f16.f32 %vh, %v;
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr; @%p bra DONE;
cvt.u64.u32 %moff, %idx; add.u64 %mk, %mk, %moff;
ld.global.u8 %m, [%mk]; setp.ne.u16 %sel, %m, 0;
cvt.u64.u32 %ioff, %idx; shl.b64 %ioff, %ioff, 1;
add.u64 %in, %in, %ioff; ld.global.b16 %iv, [%in];
selp.b16 %iv, %vh, %iv, %sel;
add.u64 %out, %out, %ioff; st.global.b16 [%out], %iv;
DONE: ret;
}
";
/// PTX source for `masked_fill_bf16_kernel(in_ptr, mask_ptr, out_ptr, value, n)`.
///
/// The fill `value` is passed as `f32` and rounded to bfloat16 with
/// `cvt.rn.bf16.f32` inside the kernel. Elements are moved as raw 16-bit words
/// (`.b16`, offset `idx << 1`): `out[idx]` receives the converted value where
/// the mask byte is non-zero and `in[idx]` elsewhere. Unlike the other
/// masked-fill kernels this one declares `.version 7.8` / `.target sm_80`.
const MASKED_FILL_BF16_PTX: &str = "\
.version 7.8
.target sm_80
.address_size 64
.visible .entry masked_fill_bf16_kernel(
.param .u64 in_ptr, .param .u64 mask_ptr, .param .u64 out_ptr,
.param .f32 value, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %in, %mk, %out, %ioff, %moff;
.reg .f32 %v; .reg .b16 %vh, %iv; .reg .u16 %m; .reg .pred %p, %sel;
ld.param.u64 %in, [in_ptr]; ld.param.u64 %mk, [mask_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.f32 %v, [value]; ld.param.u32 %nr, [n];
cvt.rn.bf16.f32 %vh, %v;
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr; @%p bra DONE;
cvt.u64.u32 %moff, %idx; add.u64 %mk, %mk, %moff;
ld.global.u8 %m, [%mk]; setp.ne.u16 %sel, %m, 0;
cvt.u64.u32 %ioff, %idx; shl.b64 %ioff, %ioff, 1;
add.u64 %in, %in, %ioff; ld.global.b16 %iv, [%in];
selp.b16 %iv, %vh, %iv, %sel;
add.u64 %out, %out, %ioff; st.global.b16 [%out], %iv;
DONE: ret;
}
";
/// PTX source for `masked_fill_i32_kernel(in_ptr, mask_ptr, out_ptr, value, n)`.
///
/// The fill value is declared `.u32` and every load, select and store is
/// untyped 32-bit (`.b32`), so the kernel copies bit patterns without
/// interpreting them. `out[idx]` receives `value` where the mask byte is
/// non-zero and `in[idx]` elsewhere; threads with `idx >= n` do nothing.
const MASKED_FILL_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry masked_fill_i32_kernel(
.param .u64 in_ptr, .param .u64 mask_ptr, .param .u64 out_ptr,
.param .u32 value, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %in, %mk, %out, %ioff, %moff;
.reg .b32 %v, %iv; .reg .u16 %m; .reg .pred %p, %sel;
ld.param.u64 %in, [in_ptr]; ld.param.u64 %mk, [mask_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.b32 %v, [value]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr; @%p bra DONE;
cvt.u64.u32 %moff, %idx; add.u64 %mk, %mk, %moff;
ld.global.u8 %m, [%mk]; setp.ne.u16 %sel, %m, 0;
cvt.u64.u32 %ioff, %idx; shl.b64 %ioff, %ioff, 2;
add.u64 %in, %in, %ioff; ld.global.b32 %iv, [%in];
selp.b32 %iv, %v, %iv, %sel;
add.u64 %out, %out, %ioff; st.global.b32 [%out], %iv;
DONE: ret;
}
";
/// PTX source for `masked_fill_i64_kernel(in_ptr, mask_ptr, out_ptr, value, n)`.
///
/// The fill value is declared `.u64` and every load, select and store is
/// untyped 64-bit (`.b64`, offset `idx << 3`), so the kernel copies bit
/// patterns without interpreting them. `out[idx]` receives `value` where the
/// mask byte is non-zero and `in[idx]` elsewhere; threads with `idx >= n` do
/// nothing.
const MASKED_FILL_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry masked_fill_i64_kernel(
.param .u64 in_ptr, .param .u64 mask_ptr, .param .u64 out_ptr,
.param .u64 value, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %in, %mk, %out, %ioff, %moff;
.reg .b64 %v, %iv; .reg .u16 %m; .reg .pred %p, %sel;
ld.param.u64 %in, [in_ptr]; ld.param.u64 %mk, [mask_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.b64 %v, [value]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr; @%p bra DONE;
cvt.u64.u32 %moff, %idx; add.u64 %mk, %mk, %moff;
ld.global.u8 %m, [%mk]; setp.ne.u16 %sel, %m, 0;
cvt.u64.u32 %ioff, %idx; shl.b64 %ioff, %ioff, 3;
add.u64 %in, %in, %ioff; ld.global.b64 %iv, [%in];
selp.b64 %iv, %v, %iv, %sel;
add.u64 %out, %out, %ioff; st.global.b64 [%out], %iv;
DONE: ret;
}
";
/// PTX source for `where_32_kernel(cond_ptr, x_ptr, y_ptr, out_ptr, n)`.
///
/// One thread per element: thread `idx` returns when `idx >= n`; otherwise it
/// loads the condition byte `cond[idx]` and both 32-bit operands, and stores
/// `x[idx]` to `out[idx]` if the condition byte is non-zero, `y[idx]` if it is
/// zero. All value traffic is untyped `.b32`, so any 4-byte element works.
const WHERE_32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry where_32_kernel(
.param .u64 cond_ptr, .param .u64 x_ptr, .param .u64 y_ptr,
.param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %cd, %x, %y, %out, %voff, %coff;
.reg .b32 %xv, %yv, %r; .reg .u16 %c; .reg .pred %p, %sel;
ld.param.u64 %cd, [cond_ptr]; ld.param.u64 %x, [x_ptr]; ld.param.u64 %y, [y_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr; @%p bra DONE;
cvt.u64.u32 %coff, %idx; add.u64 %cd, %cd, %coff;
ld.global.u8 %c, [%cd]; setp.ne.u16 %sel, %c, 0;
cvt.u64.u32 %voff, %idx; shl.b64 %voff, %voff, 2;
add.u64 %x, %x, %voff; ld.global.b32 %xv, [%x];
add.u64 %y, %y, %voff; ld.global.b32 %yv, [%y];
selp.b32 %r, %xv, %yv, %sel;
add.u64 %out, %out, %voff; st.global.b32 [%out], %r;
DONE: ret;
}
";
/// PTX source for `where_64_kernel(cond_ptr, x_ptr, y_ptr, out_ptr, n)`.
///
/// Same selection as the 32-bit kernel but with untyped 64-bit values
/// (`.b64`, offset `idx << 3`): `out[idx]` is `x[idx]` where the condition
/// byte is non-zero and `y[idx]` where it is zero. Threads with `idx >= n` do
/// nothing.
const WHERE_64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry where_64_kernel(
.param .u64 cond_ptr, .param .u64 x_ptr, .param .u64 y_ptr,
.param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %cd, %x, %y, %out, %voff, %coff;
.reg .b64 %xv, %yv, %r; .reg .u16 %c; .reg .pred %p, %sel;
ld.param.u64 %cd, [cond_ptr]; ld.param.u64 %x, [x_ptr]; ld.param.u64 %y, [y_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr; @%p bra DONE;
cvt.u64.u32 %coff, %idx; add.u64 %cd, %cd, %coff;
ld.global.u8 %c, [%cd]; setp.ne.u16 %sel, %c, 0;
cvt.u64.u32 %voff, %idx; shl.b64 %voff, %voff, 3;
add.u64 %x, %x, %voff; ld.global.b64 %xv, [%x];
add.u64 %y, %y, %voff; ld.global.b64 %yv, [%y];
selp.b64 %r, %xv, %yv, %sel;
add.u64 %out, %out, %voff; st.global.b64 [%out], %r;
DONE: ret;
}
";
/// PTX source for `where_16_kernel(cond_ptr, x_ptr, y_ptr, out_ptr, n)`.
///
/// Same selection as the 32-bit kernel but with untyped 16-bit values
/// (`.b16`, offset `idx << 1`): `out[idx]` is `x[idx]` where the condition
/// byte is non-zero and `y[idx]` where it is zero. Because the words are never
/// interpreted, the kernel does not distinguish between 16-bit formats.
const WHERE_16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry where_16_kernel(
.param .u64 cond_ptr, .param .u64 x_ptr, .param .u64 y_ptr,
.param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %cd, %x, %y, %out, %voff, %coff;
.reg .b16 %xv, %yv, %r; .reg .u16 %c; .reg .pred %p, %sel;
ld.param.u64 %cd, [cond_ptr]; ld.param.u64 %x, [x_ptr]; ld.param.u64 %y, [y_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr; @%p bra DONE;
cvt.u64.u32 %coff, %idx; add.u64 %cd, %cd, %coff;
ld.global.u8 %c, [%cd]; setp.ne.u16 %sel, %c, 0;
cvt.u64.u32 %voff, %idx; shl.b64 %voff, %voff, 1;
add.u64 %x, %x, %voff; ld.global.b16 %xv, [%x];
add.u64 %y, %y, %voff; ld.global.b16 %yv, [%y];
selp.b16 %r, %xv, %yv, %sel;
add.u64 %out, %out, %voff; st.global.b16 [%out], %r;
DONE: ret;
}
";
/// PTX source for `count_true_kernel(mask_ptr, out_ptr, n)`.
///
/// Only the thread whose global index is 0 does any work; every other thread
/// branches straight to `DONE`. That thread walks `mask[0..n]` sequentially,
/// adds 1 to a 32-bit accumulator for each non-zero byte, and finally stores
/// the accumulator as a 32-bit value to `out[0]`.
const COUNT_TRUE_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry count_true_kernel(.param .u64 mask_ptr, .param .u64 out_ptr, .param .u32 n) {
.reg .u32 %idx, %bid, %bdim, %nr, %i;
.reg .u64 %mk, %out, %off, %cur;
.reg .u16 %v; .reg .s32 %acc, %one;
.reg .pred %only0, %p, %nz;
ld.param.u64 %mk, [mask_ptr]; ld.param.u64 %out, [out_ptr]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ne.u32 %only0, %idx, 0; @%only0 bra DONE;
mov.s32 %acc, 0; mov.u32 %i, 0;
LOOP:
setp.ge.u32 %p, %i, %nr; @%p bra STORE;
cvt.u64.u32 %off, %i; add.u64 %cur, %mk, %off;
ld.global.u8 %v, [%cur];
setp.ne.u16 %nz, %v, 0; selp.s32 %one, 1, 0, %nz;
add.s32 %acc, %acc, %one;
add.u32 %i, %i, 1; bra LOOP;
STORE:
st.global.s32 [%out], %acc;
DONE: ret;
}
";
/// Generates PTX for a stream-compaction kernel
/// `kernel_name(in_ptr, mask_ptr, out_ptr, n)`.
///
/// The generated kernel is sequential: only the thread with global index 0
/// runs. It scans `i` over `0..n` and, for every non-zero `mask[i]`, copies
/// `in[i]` to `out[j]` and increments `j`, so selected elements keep their
/// relative order.
///
/// * `kernel_name` - name of the emitted `.entry`.
/// * `val_shift` - left shift applied to element indices to form byte offsets
///   (callers in this file pass log2 of the element size).
/// * `ld_st_ty` - PTX type suffix used for the element load and store
///   (e.g. `b32`).
/// * `reg_decl` - PTX register declaration spliced into the kernel; it must
///   declare `%val`, which the body uses but does not declare itself.
fn compact_ptx(kernel_name: &str, val_shift: u32, ld_st_ty: &str, reg_decl: &str) -> String {
format!(
"\
.version 7.0
.target sm_52
.address_size 64
.visible .entry {kernel_name}(.param .u64 in_ptr, .param .u64 mask_ptr, .param .u64 out_ptr, .param .u32 n) {{
.reg .u32 %idx, %bid, %bdim, %nr, %i, %j;
.reg .u64 %in, %mk, %out, %ioff, %ooff, %icur, %mcur, %ocur;
.reg .u16 %m; {reg_decl}
.reg .pred %only0, %p, %nz;
ld.param.u64 %in, [in_ptr]; ld.param.u64 %mk, [mask_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ne.u32 %only0, %idx, 0; @%only0 bra DONE;
mov.u32 %i, 0; mov.u32 %j, 0;
LOOP:
setp.ge.u32 %p, %i, %nr; @%p bra DONE;
cvt.u64.u32 %ioff, %i; add.u64 %mcur, %mk, %ioff;
ld.global.u8 %m, [%mcur]; setp.ne.u16 %nz, %m, 0;
@!%nz bra NEXT;
// mask[i] true: out[j] = input[i]
shl.b64 %ioff, %ioff, {val_shift}; add.u64 %icur, %in, %ioff;
ld.global.{ld_st_ty} %val, [%icur];
cvt.u64.u32 %ooff, %j; shl.b64 %ooff, %ooff, {val_shift}; add.u64 %ocur, %out, %ooff;
st.global.{ld_st_ty} [%ocur], %val;
add.u32 %j, %j, 1;
NEXT:
add.u32 %i, %i, 1; bra LOOP;
DONE: ret;
}}
"
)
}
/// Generates PTX for a masked-scatter kernel
/// `kernel_name(grad_ptr, mask_ptr, out_ptr, n)`.
///
/// The generated kernel is sequential: only the thread with global index 0
/// runs. It scans `i` over `0..n` and, for every non-zero `mask[i]`, copies
/// `grad[j]` to `out[i]` and increments `j`. Positions where the mask is zero
/// are never written, so they keep whatever the output buffer already held.
///
/// * `kernel_name` - name of the emitted `.entry`.
/// * `val_shift` - left shift applied to element indices to form byte offsets
///   (callers in this file pass log2 of the element size).
/// * `ld_st_ty` - PTX type suffix used for the element load and store
///   (e.g. `b32`).
/// * `reg_decl` - PTX register declaration spliced into the kernel; it must
///   declare `%val`, which the body uses but does not declare itself.
fn scatter_ptx(kernel_name: &str, val_shift: u32, ld_st_ty: &str, reg_decl: &str) -> String {
format!(
"\
.version 7.0
.target sm_52
.address_size 64
.visible .entry {kernel_name}(.param .u64 grad_ptr, .param .u64 mask_ptr, .param .u64 out_ptr, .param .u32 n) {{
.reg .u32 %idx, %bid, %bdim, %nr, %i, %j;
.reg .u64 %gr, %mk, %out, %goff, %ooff, %gcur, %mcur, %ocur;
.reg .u16 %m; {reg_decl}
.reg .pred %only0, %p, %nz;
ld.param.u64 %gr, [grad_ptr]; ld.param.u64 %mk, [mask_ptr];
ld.param.u64 %out, [out_ptr]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ne.u32 %only0, %idx, 0; @%only0 bra DONE;
mov.u32 %i, 0; mov.u32 %j, 0;
LOOP:
setp.ge.u32 %p, %i, %nr; @%p bra DONE;
cvt.u64.u32 %ooff, %i; add.u64 %mcur, %mk, %ooff;
ld.global.u8 %m, [%mcur]; setp.ne.u16 %nz, %m, 0;
@!%nz bra NEXT;
// mask[i] true: out[i] = grad[j]
cvt.u64.u32 %goff, %j; shl.b64 %goff, %goff, {val_shift}; add.u64 %gcur, %gr, %goff;
ld.global.{ld_st_ty} %val, [%gcur];
shl.b64 %ooff, %ooff, {val_shift}; add.u64 %ocur, %out, %ooff;
st.global.{ld_st_ty} [%ocur], %val;
add.u32 %j, %j, 1;
NEXT:
add.u32 %i, %i, 1; bra LOOP;
DONE: ret;
}}
"
)
}
/// Generates PTX for a forward masked-scatter kernel
/// `kernel_name(in_ptr, src_ptr, mask_ptr, out_ptr, n)`.
///
/// The generated kernel is sequential: only the thread with global index 0
/// runs. It scans `i` over `0..n`; where `mask[i]` is non-zero it copies
/// `src[j]` to `out[i]` and increments `j`, and where `mask[i]` is zero it
/// copies `in[i]` to `out[i]`. Every output position in `0..n` is written.
///
/// * `kernel_name` - name of the emitted `.entry`.
/// * `val_shift` - left shift applied to element indices to form byte offsets
///   (callers in this file pass log2 of the element size).
/// * `ld_st_ty` - PTX type suffix used for the element loads and stores
///   (e.g. `b32`).
/// * `reg_decl` - PTX register declaration spliced into the kernel; it must
///   declare `%val`, which the body uses but does not declare itself.
fn scatter_forward_ptx(
kernel_name: &str,
val_shift: u32,
ld_st_ty: &str,
reg_decl: &str,
) -> String {
format!(
"\
.version 7.0
.target sm_52
.address_size 64
.visible .entry {kernel_name}(.param .u64 in_ptr, .param .u64 src_ptr, .param .u64 mask_ptr, .param .u64 out_ptr, .param .u32 n) {{
.reg .u32 %idx, %bid, %bdim, %nr, %i, %j;
.reg .u64 %in, %sr, %mk, %out, %eoff, %soff, %icur, %scur, %mcur, %ocur;
.reg .u16 %m; {reg_decl}
.reg .pred %only0, %p, %nz;
ld.param.u64 %in, [in_ptr]; ld.param.u64 %sr, [src_ptr];
ld.param.u64 %mk, [mask_ptr]; ld.param.u64 %out, [out_ptr]; ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ne.u32 %only0, %idx, 0; @%only0 bra DONE;
mov.u32 %i, 0; mov.u32 %j, 0;
LOOP:
setp.ge.u32 %p, %i, %nr; @%p bra DONE;
cvt.u64.u32 %eoff, %i; add.u64 %mcur, %mk, %eoff;
ld.global.u8 %m, [%mcur]; setp.ne.u16 %nz, %m, 0;
shl.b64 %eoff, %eoff, {val_shift};
@!%nz bra COPYIN;
// mask[i] true: out[i] = src[j]; j++
cvt.u64.u32 %soff, %j; shl.b64 %soff, %soff, {val_shift}; add.u64 %scur, %sr, %soff;
ld.global.{ld_st_ty} %val, [%scur];
add.u64 %ocur, %out, %eoff; st.global.{ld_st_ty} [%ocur], %val;
add.u32 %j, %j, 1;
bra NEXT;
COPYIN:
// mask[i] false: out[i] = in[i]
add.u64 %icur, %in, %eoff; ld.global.{ld_st_ty} %val, [%icur];
add.u64 %ocur, %out, %eoff; st.global.{ld_st_ty} [%ocur], %val;
NEXT:
add.u32 %i, %i, 1; bra LOOP;
DONE: ret;
}}
"
)
}
/// Compiles (or fetches from the module cache) a masked-fill kernel and runs
/// it over the first `n` elements of `input` and `mask`.
///
/// Returns a freshly allocated `n`-element buffer. The kernel is launched with
/// the arguments `(input, mask, out, value, n)`, which is the parameter order
/// of every `masked_fill_*` PTX in this file.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `input` or `mask` holds fewer than `n`
///   elements, or if `n` exceeds `u32::MAX` (the kernels take the element
///   count as a 32-bit parameter and index with 32-bit registers).
/// * [`GpuError::PtxCompileFailed`] if the PTX cannot be compiled or loaded.
/// * Any allocation or launch error reported by the driver.
#[allow(clippy::too_many_arguments)]
fn launch_masked_fill<T, S>(
    input: &CudaSlice<T>,
    mask: &CudaSlice<u8>,
    value: S,
    n: usize,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<CudaSlice<T>>
where
    T: DeviceRepr + ValidAsZeroBits,
    S: DeviceRepr,
{
    if input.len() < n || mask.len() < n {
        return Err(GpuError::LengthMismatch {
            a: input.len().min(mask.len()),
            b: n,
        });
    }
    // Reject instead of truncating: a wrapped count would make the kernel
    // process only `n mod 2^32` elements of an `n`-element output.
    if n > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: u32::MAX as usize,
        });
    }
    let n_u32 = n as u32;

    let stream = device.stream();
    if n == 0 {
        return Ok(stream.alloc_zeros::<T>(0)?);
    }
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out = stream.alloc_zeros::<T>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the kernel touches indices `0..n` only; `input` and `mask` were
    // checked to hold at least `n` elements and `out` was allocated with `n`.
    // The caller is responsible for `ptx` declaring the parameter list
    // `(in, mask, out, value, n)` with `value` matching `S`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(mask)
            .arg(&mut out)
            .arg(&value)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Compiles (or fetches from the module cache) a `where` kernel and runs it
/// over the first `n` elements of `cond`, `x` and `y`.
///
/// Returns a freshly allocated `n`-element buffer. The kernel is launched with
/// the arguments `(cond, x, y, out, n)`, which is the parameter order of every
/// `where_*` PTX in this file.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if any of `cond`, `x`, `y` holds fewer than
///   `n` elements, or if `n` exceeds `u32::MAX` (the kernels take the element
///   count as a 32-bit parameter and index with 32-bit registers).
/// * [`GpuError::PtxCompileFailed`] if the PTX cannot be compiled or loaded.
/// * Any allocation or launch error reported by the driver.
fn launch_where<T: DeviceRepr + ValidAsZeroBits>(
    cond: &CudaSlice<u8>,
    x: &CudaSlice<T>,
    y: &CudaSlice<T>,
    n: usize,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<CudaSlice<T>> {
    if x.len() < n || y.len() < n || cond.len() < n {
        return Err(GpuError::LengthMismatch {
            a: x.len().min(y.len()).min(cond.len()),
            b: n,
        });
    }
    // Reject instead of truncating: a wrapped count would make the kernel
    // process only `n mod 2^32` elements of an `n`-element output.
    if n > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: u32::MAX as usize,
        });
    }
    let n_u32 = n as u32;

    let stream = device.stream();
    if n == 0 {
        return Ok(stream.alloc_zeros::<T>(0)?);
    }
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out = stream.alloc_zeros::<T>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the kernel touches indices `0..n` only; `cond`, `x` and `y` were
    // checked to hold at least `n` elements and `out` was allocated with `n`.
    // The caller is responsible for `ptx` declaring the parameter list
    // `(cond, x, y, out, n)` with an element width equal to `size_of::<T>()`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(cond)
            .arg(x)
            .arg(y)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Counts the non-zero bytes in a device-resident mask.
///
/// Launches `count_true_kernel` with a single thread, which scans the whole
/// mask sequentially and writes the total into a one-element device buffer;
/// that buffer is then copied back to the host. An empty mask returns `0`
/// without launching anything.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if the mask holds more than `u32::MAX`
///   bytes (the kernel takes its length as a 32-bit parameter).
/// * [`GpuError::PtxCompileFailed`] if the PTX cannot be compiled or loaded.
/// * Any allocation, launch or copy error reported by the driver.
pub fn count_true(mask: &CudaSlice<u8>, device: &GpuDevice) -> GpuResult<usize> {
    let n = mask.len();
    let stream = device.stream();
    if n == 0 {
        return Ok(0);
    }
    // Reject instead of truncating: a wrapped length would silently count
    // only the first `n mod 2^32` bytes.
    if n > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: u32::MAX as usize,
        });
    }
    let n_u32 = n as u32;

    let ctx = device.context();
    let f = get_or_compile(
        ctx,
        COUNT_TRUE_PTX,
        "count_true_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "count_true_kernel",
        source: e,
    })?;
    // The kernel stores a 32-bit two's-complement sum. Reading it back as
    // `u32` keeps counts above `i32::MAX` intact; interpreting it as `i32`
    // and clamping negatives would report such counts as zero.
    let mut out = stream.alloc_zeros::<u32>(1)?;
    let cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (1, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the kernel reads `mask[0..n]` with `n == mask.len()` and writes
    // one 32-bit value to `out`, which holds exactly one 32-bit element.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(mask)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    let host = stream.clone_dtoh(&out)?;
    Ok(host[0] as usize)
}
/// Compiles (or fetches from the module cache) a compaction kernel and runs
/// it over all of `input`, gathering the elements whose mask byte is non-zero
/// into a new `out_len`-element buffer.
///
/// The kernel runs on a single thread and is launched with the arguments
/// `(input, mask, out, n)` where `n == input.len()`.
///
/// `out_len` must be at least the number of non-zero bytes among the first
/// `input.len()` mask bytes; the kernel writes one output element per such
/// byte and this function does not re-count them. Callers in this file pass
/// the result of `count_true(mask)`.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `mask` is shorter than `input` (the
///   kernel would otherwise read past the end of the mask), or if
///   `input.len()` exceeds `u32::MAX`.
/// * [`GpuError::PtxCompileFailed`] if the PTX cannot be compiled or loaded.
/// * Any allocation or launch error reported by the driver.
fn launch_compact<T: DeviceRepr + ValidAsZeroBits>(
    input: &CudaSlice<T>,
    mask: &CudaSlice<u8>,
    out_len: usize,
    device: &GpuDevice,
    ptx: String,
    kernel_name: String,
) -> GpuResult<CudaSlice<T>> {
    let n = input.len();
    // The kernel reads `mask[i]` for every `i < n`; a shorter mask would be
    // an out-of-bounds device read.
    if mask.len() < n {
        return Err(GpuError::LengthMismatch {
            a: mask.len(),
            b: n,
        });
    }
    // Reject instead of truncating: a wrapped count would make the kernel
    // scan only the first `n mod 2^32` elements.
    if n > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: u32::MAX as usize,
        });
    }
    let n_u32 = n as u32;

    let stream = device.stream();
    if out_len == 0 {
        return Ok(stream.alloc_zeros::<T>(0)?);
    }
    let ctx = device.context();
    let f =
        crate::module_cache::get_or_compile_owned(ctx, ptx, kernel_name, device.ordinal() as u32)
            .map_err(|e| GpuError::PtxCompileFailed {
                kernel: "masked_select_compact",
                source: e,
            })?;
    let mut out = stream.alloc_zeros::<T>(out_len)?;
    let cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (1, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the kernel reads `input[0..n]` and `mask[0..n]`; `n` is
    // `input.len()` and `mask.len() >= n` was checked above. It writes one
    // element of `out` per non-zero mask byte, which stays in bounds as long
    // as the caller upholds the documented `out_len` precondition.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(mask)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Compiles (or fetches from the module cache) a masked-scatter kernel and
/// runs it, producing a zero-initialised `out_numel`-element buffer in which
/// the positions whose mask byte is non-zero are filled, in order, from
/// `grad`.
///
/// The kernel runs on a single thread and is launched with the arguments
/// `(grad, mask, out, n)` where `n == out_numel`.
///
/// `grad` must hold at least as many elements as `mask` has non-zero bytes:
/// the kernel reads one `grad` element per such byte and this function does
/// not count them, so a shorter `grad` is read out of bounds on the device.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `mask.len() != out_numel`, or if
///   `out_numel` exceeds `u32::MAX` (the kernel takes the element count as a
///   32-bit parameter).
/// * [`GpuError::PtxCompileFailed`] if the PTX cannot be compiled or loaded.
/// * Any allocation or launch error reported by the driver.
fn launch_scatter<T: DeviceRepr + ValidAsZeroBits>(
    grad: &CudaSlice<T>,
    mask: &CudaSlice<u8>,
    out_numel: usize,
    device: &GpuDevice,
    ptx: String,
    kernel_name: String,
) -> GpuResult<CudaSlice<T>> {
    if mask.len() != out_numel {
        return Err(GpuError::LengthMismatch {
            a: mask.len(),
            b: out_numel,
        });
    }
    // Reject instead of truncating: a wrapped count would leave everything
    // past `out_numel mod 2^32` unscattered.
    if out_numel > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: out_numel,
            b: u32::MAX as usize,
        });
    }
    let n_u32 = out_numel as u32;

    let stream = device.stream();
    if out_numel == 0 {
        return Ok(stream.alloc_zeros::<T>(0)?);
    }
    let ctx = device.context();
    let f =
        crate::module_cache::get_or_compile_owned(ctx, ptx, kernel_name, device.ordinal() as u32)
            .map_err(|e| GpuError::PtxCompileFailed {
                kernel: "masked_scatter",
                source: e,
            })?;
    let mut out = stream.alloc_zeros::<T>(out_numel)?;
    let cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (1, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the kernel reads `mask[0..n]` and writes `out[i]` for `i < n`;
    // both buffers hold exactly `n == out_numel` elements. The reads of
    // `grad` are in bounds only if the caller upholds the documented `grad`
    // length precondition.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(grad)
            .arg(mask)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Compiles (or fetches from the module cache) a forward masked-scatter
/// kernel and runs it over the first `n` elements, producing a new
/// `n`-element buffer that holds the next unused `source` element wherever the
/// mask byte is non-zero and the matching `input` element everywhere else.
///
/// The kernel runs on a single thread and is launched with the arguments
/// `(input, source, mask, out, n)`.
///
/// `source` must hold at least as many elements as there are non-zero bytes
/// in `mask[0..n]`: the kernel reads one `source` element per such byte and
/// this function does not count them, so a shorter `source` is read out of
/// bounds on the device.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `input` or `mask` holds fewer than `n`
///   elements, or if `n` exceeds `u32::MAX` (the kernel takes the element
///   count as a 32-bit parameter).
/// * [`GpuError::PtxCompileFailed`] if the PTX cannot be compiled or loaded.
/// * Any allocation or launch error reported by the driver.
fn launch_scatter_forward<T: DeviceRepr + ValidAsZeroBits>(
    input: &CudaSlice<T>,
    source: &CudaSlice<T>,
    mask: &CudaSlice<u8>,
    n: usize,
    device: &GpuDevice,
    ptx: String,
    kernel_name: String,
) -> GpuResult<CudaSlice<T>> {
    if input.len() < n || mask.len() < n {
        return Err(GpuError::LengthMismatch {
            a: input.len().min(mask.len()),
            b: n,
        });
    }
    // Reject instead of truncating: a wrapped count would make the kernel
    // fill only the first `n mod 2^32` elements of an `n`-element output.
    if n > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: u32::MAX as usize,
        });
    }
    let n_u32 = n as u32;

    let stream = device.stream();
    if n == 0 {
        return Ok(stream.alloc_zeros::<T>(0)?);
    }
    let ctx = device.context();
    let f =
        crate::module_cache::get_or_compile_owned(ctx, ptx, kernel_name, device.ordinal() as u32)
            .map_err(|e| GpuError::PtxCompileFailed {
                kernel: "masked_scatter_forward",
                source: e,
            })?;
    let mut out = stream.alloc_zeros::<T>(n)?;
    let cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (1, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the kernel reads `input[i]` / `mask[i]` and writes `out[i]` for
    // `i < n`; `input` and `mask` were checked to hold at least `n` elements
    // and `out` was allocated with `n`. The reads of `source` are in bounds
    // only if the caller upholds the documented `source` length precondition.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(source)
            .arg(mask)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Generates PTX for a finiteness-test kernel `kernel_name(in_ptr, out_ptr, n)`.
///
/// One thread per element: thread `idx` returns when `idx >= n`; otherwise it
/// loads `in[idx]` and stores the byte `1` to `out[idx]` when the value equals
/// itself (i.e. is not NaN) and its absolute value differs from `inf_lit`, and
/// the byte `0` otherwise.
///
/// * `kernel_name` - name of the emitted `.entry`.
/// * `ty` - PTX floating-point type used for registers, loads and compares
///   (callers in this file pass `f32` or `f64`).
/// * `in_shift` - left shift applied to `idx` to form the input byte offset
///   (callers pass log2 of the element size).
/// * `inf_lit` - PTX literal moved into the `%inf` register and compared
///   against `|v|`; callers pass the positive-infinity bit pattern for `ty`.
fn isfinite_ptx(kernel_name: &str, ty: &str, in_shift: u32, inf_lit: &str) -> String {
format!(
"\
.version 7.0
.target sm_52
.address_size 64
.visible .entry {kernel_name}(
.param .u64 in_ptr, .param .u64 out_ptr, .param .u32 n
) {{
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %out, %ioff, %ooff;
.reg .{ty} %v, %av, %inf;
.reg .u16 %res;
.reg .pred %p, %notnan, %notinf, %fin;
ld.param.u64 %a, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %ioff, %idx;
shl.b64 %ioff, %ioff, {in_shift};
add.u64 %a, %a, %ioff;
cvt.u64.u32 %ooff, %idx;
add.u64 %out, %out, %ooff;
ld.global.{ty} %v, [%a];
// not-NaN: v == v (setp.eq is unordered-false, so NaN -> false)
setp.eq.{ty} %notnan, %v, %v;
// |v| != inf
abs.{ty} %av, %v;
mov.{ty} %inf, {inf_lit};
setp.ne.{ty} %notinf, %av, %inf;
and.pred %fin, %notnan, %notinf;
selp.u16 %res, 1, 0, %fin;
st.global.u8 [%out], %res;
DONE:
ret;
}}
"
)
}
/// Generates PTX for a "not equal to scalar" kernel
/// `kernel_name(in_ptr, out_ptr, value, n)`.
///
/// One thread per element: thread `idx` returns when `idx >= n`; otherwise it
/// loads `in[idx]`, compares it with `value` using `setp.neu` (the unordered
/// not-equal, which is true when either operand is NaN) and stores the byte
/// `1` to `out[idx]` if the comparison holds, `0` otherwise.
///
/// * `kernel_name` - name of the emitted `.entry`.
/// * `ty` - PTX floating-point type of the elements and of the `value`
///   parameter (callers in this file pass `f32` or `f64`).
/// * `in_shift` - left shift applied to `idx` to form the input byte offset
///   (callers pass log2 of the element size).
fn ne_scalar_ptx(kernel_name: &str, ty: &str, in_shift: u32) -> String {
format!(
"\
.version 7.0
.target sm_52
.address_size 64
.visible .entry {kernel_name}(
.param .u64 in_ptr, .param .u64 out_ptr, .param .{ty} value, .param .u32 n
) {{
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %out, %ioff, %ooff;
.reg .{ty} %v, %val;
.reg .u16 %res;
.reg .pred %p, %c;
ld.param.u64 %a, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.{ty} %val, [value];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %ioff, %idx;
shl.b64 %ioff, %ioff, {in_shift};
add.u64 %a, %a, %ioff;
cvt.u64.u32 %ooff, %idx;
add.u64 %out, %out, %ooff;
ld.global.{ty} %v, [%a];
// setp.neu.f is the UNORDERED not-equal: NaN != value -> true (matches the
// CPU `v != value` walk where Rust `NaN != x` is true). Plain `setp.ne.f`
// is the *ordered* form (NaN -> false), which would diverge from the CPU
// reference, so `.neu` is required here.
setp.neu.{ty} %c, %v, %val;
selp.u16 %res, 1, 0, %c;
st.global.u8 [%out], %res;
DONE:
ret;
}}
"
)
}
/// Compiles (or fetches from the module cache) an element-wise predicate
/// kernel and runs it over all of `input`, returning one result byte per
/// input element.
///
/// The kernel is launched with the arguments `(input, out, n)` where
/// `n == input.len()`. An empty input returns an empty buffer without
/// compiling or launching anything.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `input.len()` exceeds `u32::MAX` (the
///   kernels take the element count as a 32-bit parameter and index with
///   32-bit registers).
/// * [`GpuError::PtxCompileFailed`] if the PTX cannot be compiled or loaded.
/// * Any allocation or launch error reported by the driver.
fn launch_predicate<T: DeviceRepr + ValidAsZeroBits>(
    input: &CudaSlice<T>,
    device: &GpuDevice,
    ptx: String,
    kernel_name: String,
) -> GpuResult<CudaSlice<u8>> {
    let n = input.len();
    let stream = device.stream();
    if n == 0 {
        return Ok(stream.alloc_zeros::<u8>(0)?);
    }
    // Reject instead of truncating: a wrapped count would leave every mask
    // byte past `n mod 2^32` at its zero-initialised value.
    if n > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: u32::MAX as usize,
        });
    }
    let n_u32 = n as u32;

    let ctx = device.context();
    let f =
        crate::module_cache::get_or_compile_owned(ctx, ptx, kernel_name, device.ordinal() as u32)
            .map_err(|e| GpuError::PtxCompileFailed {
                kernel: "masked_predicate",
                source: e,
            })?;
    let mut out = stream.alloc_zeros::<u8>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the kernel touches indices `0..n` only; `n` is `input.len()`
    // and `out` was allocated with `n` bytes. The caller is responsible for
    // `ptx` declaring the parameter list `(in, out, n)` with an element width
    // equal to `size_of::<T>()`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Compiles (or fetches from the module cache) an element-wise predicate
/// kernel that also takes a scalar operand, and runs it over all of `input`,
/// returning one result byte per input element.
///
/// The kernel is launched with the arguments `(input, out, value, n)` where
/// `n == input.len()`. An empty input returns an empty buffer without
/// compiling or launching anything.
///
/// # Errors
///
/// * [`GpuError::LengthMismatch`] if `input.len()` exceeds `u32::MAX` (the
///   kernels take the element count as a 32-bit parameter and index with
///   32-bit registers).
/// * [`GpuError::PtxCompileFailed`] if the PTX cannot be compiled or loaded.
/// * Any allocation or launch error reported by the driver.
fn launch_predicate_scalar<T: DeviceRepr + ValidAsZeroBits, S: DeviceRepr>(
    input: &CudaSlice<T>,
    value: S,
    device: &GpuDevice,
    ptx: String,
    kernel_name: String,
) -> GpuResult<CudaSlice<u8>> {
    let n = input.len();
    let stream = device.stream();
    if n == 0 {
        return Ok(stream.alloc_zeros::<u8>(0)?);
    }
    // Reject instead of truncating: a wrapped count would leave every mask
    // byte past `n mod 2^32` at its zero-initialised value.
    if n > u32::MAX as usize {
        return Err(GpuError::LengthMismatch {
            a: n,
            b: u32::MAX as usize,
        });
    }
    let n_u32 = n as u32;

    let ctx = device.context();
    let f =
        crate::module_cache::get_or_compile_owned(ctx, ptx, kernel_name, device.ordinal() as u32)
            .map_err(|e| GpuError::PtxCompileFailed {
                kernel: "masked_predicate_scalar",
                source: e,
            })?;
    let mut out = stream.alloc_zeros::<u8>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the kernel touches indices `0..n` only; `n` is `input.len()`
    // and `out` was allocated with `n` bytes. The caller is responsible for
    // `ptx` declaring the parameter list `(in, out, value, n)` with `value`
    // matching `S` and an element width equal to `size_of::<T>()`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(&mut out)
            .arg(&value)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Builds a byte mask marking the finite elements of an `f32` buffer.
///
/// The result has one byte per input element: `1` where the value is neither
/// NaN nor an infinity, `0` otherwise.
pub fn isfinite_mask_f32(input: &CudaSlice<f32>, d: &GpuDevice) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "isfinite_mask_f32_kernel";
    // 4-byte elements (shift 2); `0f7F800000` is +inf as a PTX f32 literal.
    let source = isfinite_ptx(KERNEL, "f32", 2, "0f7F800000");
    launch_predicate(input, d, source, KERNEL.to_owned())
}
/// Builds a byte mask marking the finite elements of an `f64` buffer.
///
/// The result has one byte per input element: `1` where the value is neither
/// NaN nor an infinity, `0` otherwise.
pub fn isfinite_mask_f64(input: &CudaSlice<f64>, d: &GpuDevice) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "isfinite_mask_f64_kernel";
    // 8-byte elements (shift 3); `0d7FF0000000000000` is +inf as a PTX f64 literal.
    let source = isfinite_ptx(KERNEL, "f64", 3, "0d7FF0000000000000");
    launch_predicate(input, d, source, KERNEL.to_owned())
}
/// Builds a byte mask marking the elements of an `f32` buffer that differ
/// from `value`.
///
/// The result has one byte per input element: `1` where `input[i] != value`
/// (NaN elements compare as different), `0` where they are equal.
pub fn ne_scalar_mask_f32(
    input: &CudaSlice<f32>,
    value: f32,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "ne_scalar_mask_f32_kernel";
    // 4-byte elements, hence an index shift of 2.
    let source = ne_scalar_ptx(KERNEL, "f32", 2);
    launch_predicate_scalar(input, value, d, source, KERNEL.to_owned())
}
/// Builds a byte mask marking the elements of an `f64` buffer that differ
/// from `value`.
///
/// The result has one byte per input element: `1` where `input[i] != value`
/// (NaN elements compare as different), `0` where they are equal.
pub fn ne_scalar_mask_f64(
    input: &CudaSlice<f64>,
    value: f64,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "ne_scalar_mask_f64_kernel";
    // 8-byte elements, hence an index shift of 3.
    let source = ne_scalar_ptx(KERNEL, "f64", 3);
    launch_predicate_scalar(input, value, d, source, KERNEL.to_owned())
}
/// Masked fill for `f32` buffers.
///
/// Returns a new `n`-element buffer equal to `input[..n]` except that every
/// position whose mask byte is non-zero holds `value`.
pub fn masked_fill_f32(
    input: &CudaSlice<f32>,
    mask: &CudaSlice<u8>,
    value: f32,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<f32>> {
    const KERNEL: &str = "masked_fill_f32_kernel";
    launch_masked_fill(input, mask, value, n, d, MASKED_FILL_F32_PTX, KERNEL)
}
/// Masked fill for `f64` buffers.
///
/// Returns a new `n`-element buffer equal to `input[..n]` except that every
/// position whose mask byte is non-zero holds `value`.
pub fn masked_fill_f64(
    input: &CudaSlice<f64>,
    mask: &CudaSlice<u8>,
    value: f64,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<f64>> {
    const KERNEL: &str = "masked_fill_f64_kernel";
    launch_masked_fill(input, mask, value, n, d, MASKED_FILL_F64_PTX, KERNEL)
}
/// Masked fill for half-precision buffers stored as raw `u16` words.
///
/// `value` is supplied as `f32`; the kernel rounds it to f16 on the device.
/// Returns a new `n`-element buffer equal to `input[..n]` except that every
/// position whose mask byte is non-zero holds the converted value.
pub fn masked_fill_f16(
    input: &CudaSlice<u16>,
    mask: &CudaSlice<u8>,
    value: f32,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u16>> {
    const KERNEL: &str = "masked_fill_f16_kernel";
    launch_masked_fill(input, mask, value, n, d, MASKED_FILL_F16_PTX, KERNEL)
}
/// Masked fill for bfloat16 buffers stored as raw `u16` words.
///
/// `value` is supplied as `f32`; the kernel rounds it to bf16 on the device.
/// Returns a new `n`-element buffer equal to `input[..n]` except that every
/// position whose mask byte is non-zero holds the converted value.
pub fn masked_fill_bf16(
    input: &CudaSlice<u16>,
    mask: &CudaSlice<u8>,
    value: f32,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u16>> {
    const KERNEL: &str = "masked_fill_bf16_kernel";
    launch_masked_fill(input, mask, value, n, d, MASKED_FILL_BF16_PTX, KERNEL)
}
/// Masked fill for `i32` buffers.
///
/// Returns a new `n`-element buffer equal to `input[..n]` except that every
/// position whose mask byte is non-zero holds `value`.
pub fn masked_fill_i32(
    input: &CudaSlice<i32>,
    mask: &CudaSlice<u8>,
    value: i32,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "masked_fill_i32_kernel";
    launch_masked_fill(input, mask, value, n, d, MASKED_FILL_I32_PTX, KERNEL)
}
/// Masked fill for `i64` buffers.
///
/// Returns a new `n`-element buffer equal to `input[..n]` except that every
/// position whose mask byte is non-zero holds `value`.
pub fn masked_fill_i64(
    input: &CudaSlice<i64>,
    mask: &CudaSlice<u8>,
    value: i64,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "masked_fill_i64_kernel";
    launch_masked_fill(input, mask, value, n, d, MASKED_FILL_I64_PTX, KERNEL)
}
/// Element-wise select for any 4-byte element type.
///
/// Returns a new `n`-element buffer holding `x[i]` where `cond[i]` is non-zero
/// and `y[i]` where it is zero. The kernel copies raw 32-bit words, so `T`
/// must be exactly four bytes wide; this is checked in debug builds only.
pub fn where_32<T: DeviceRepr + ValidAsZeroBits>(
    cond: &CudaSlice<u8>,
    x: &CudaSlice<T>,
    y: &CudaSlice<T>,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<T>> {
    use std::mem::size_of;

    const KERNEL: &str = "where_32_kernel";
    debug_assert_eq!(size_of::<T>(), 4, "where_32 requires a 4-byte element");
    launch_where(cond, x, y, n, d, WHERE_32_PTX, KERNEL)
}
/// Element-wise select for any 8-byte element type.
///
/// Returns a new `n`-element buffer holding `x[i]` where `cond[i]` is non-zero
/// and `y[i]` where it is zero. The kernel copies raw 64-bit words, so `T`
/// must be exactly eight bytes wide; this is checked in debug builds only.
pub fn where_64<T: DeviceRepr + ValidAsZeroBits>(
    cond: &CudaSlice<u8>,
    x: &CudaSlice<T>,
    y: &CudaSlice<T>,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<T>> {
    use std::mem::size_of;

    const KERNEL: &str = "where_64_kernel";
    debug_assert_eq!(size_of::<T>(), 8, "where_64 requires an 8-byte element");
    launch_where(cond, x, y, n, d, WHERE_64_PTX, KERNEL)
}
/// Element-wise select for 16-bit elements stored as raw `u16` words.
///
/// Returns a new `n`-element buffer holding `x[i]` where `cond[i]` is non-zero
/// and `y[i]` where it is zero. The words are copied without interpretation.
pub fn where_16(
    cond: &CudaSlice<u8>,
    x: &CudaSlice<u16>,
    y: &CudaSlice<u16>,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u16>> {
    const KERNEL: &str = "where_16_kernel";
    launch_where(cond, x, y, n, d, WHERE_16_PTX, KERNEL)
}
/// Gathers the elements of `input` whose mask byte is non-zero, for any
/// 4-byte element type.
///
/// Returns the compacted buffer together with its length, which is the number
/// of non-zero bytes in `mask`. Selected elements keep their original order.
/// The mask is first counted on the device and the count is copied back to
/// the host to size the output, then a second kernel performs the gather.
///
/// The kernel copies raw 32-bit words, so `T` must be exactly four bytes wide;
/// as with the other `*_32` entry points in this file, this is checked in
/// debug builds only.
pub fn masked_select_32<T: DeviceRepr + ValidAsZeroBits>(
    input: &CudaSlice<T>,
    mask: &CudaSlice<u8>,
    d: &GpuDevice,
) -> GpuResult<(CudaSlice<T>, usize)> {
    const KERNEL: &str = "masked_select_compact_32";
    debug_assert_eq!(
        std::mem::size_of::<T>(),
        4,
        "masked_select_32 requires a 4-byte element"
    );
    let len = count_true(mask, d)?;
    let ptx = compact_ptx(KERNEL, 2, "b32", ".reg .b32 %val;");
    let out = launch_compact(input, mask, len, d, ptx, KERNEL.to_string())?;
    Ok((out, len))
}
/// Gathers the elements of `input` whose mask byte is non-zero, for any
/// 8-byte element type.
///
/// Returns the compacted buffer together with its length, which is the number
/// of non-zero bytes in `mask`. Selected elements keep their original order.
/// The mask is first counted on the device and the count is copied back to
/// the host to size the output, then a second kernel performs the gather.
///
/// The kernel copies raw 64-bit words, so `T` must be exactly eight bytes
/// wide; as with the other `*_64` entry points in this file, this is checked
/// in debug builds only.
pub fn masked_select_64<T: DeviceRepr + ValidAsZeroBits>(
    input: &CudaSlice<T>,
    mask: &CudaSlice<u8>,
    d: &GpuDevice,
) -> GpuResult<(CudaSlice<T>, usize)> {
    const KERNEL: &str = "masked_select_compact_64";
    debug_assert_eq!(
        std::mem::size_of::<T>(),
        8,
        "masked_select_64 requires an 8-byte element"
    );
    let len = count_true(mask, d)?;
    let ptx = compact_ptx(KERNEL, 3, "b64", ".reg .b64 %val;");
    let out = launch_compact(input, mask, len, d, ptx, KERNEL.to_string())?;
    Ok((out, len))
}
/// Gathers the 16-bit elements of `input` whose mask byte is non-zero.
///
/// Returns the compacted buffer together with its length, which is the number
/// of non-zero bytes in `mask`. Selected elements keep their original order
/// and are copied as raw 16-bit words.
pub fn masked_select_16(
    input: &CudaSlice<u16>,
    mask: &CudaSlice<u8>,
    d: &GpuDevice,
) -> GpuResult<(CudaSlice<u16>, usize)> {
    const KERNEL: &str = "masked_select_compact_16";
    // The output is sized from the number of selected elements.
    let selected = count_true(mask, d)?;
    let source = compact_ptx(KERNEL, 1, "b16", ".reg .b16 %val;");
    launch_compact(input, mask, selected, d, source, KERNEL.to_owned())
        .map(|compacted| (compacted, selected))
}
/// Scatters `grad` into a zero-initialised `out_numel`-element buffer, for
/// any 4-byte element type.
///
/// Walking the mask in order, each position with a non-zero mask byte
/// receives the next element of `grad`; all other positions stay zero. The
/// kernel copies raw 32-bit words, so `T` must be exactly four bytes wide;
/// this is checked in debug builds only.
pub fn masked_scatter_32<T: DeviceRepr + ValidAsZeroBits>(
    grad: &CudaSlice<T>,
    mask: &CudaSlice<u8>,
    out_numel: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<T>> {
    use std::mem::size_of;

    const KERNEL: &str = "masked_scatter_32";
    debug_assert_eq!(
        size_of::<T>(),
        4,
        "masked_scatter_32 requires a 4-byte element"
    );
    let source = scatter_ptx(KERNEL, 2, "b32", ".reg .b32 %val;");
    launch_scatter(grad, mask, out_numel, d, source, KERNEL.to_owned())
}
/// Scatters `grad` into a zero-initialised `out_numel`-element buffer, for
/// any 8-byte element type.
///
/// Walking the mask in order, each position with a non-zero mask byte
/// receives the next element of `grad`; all other positions stay zero. The
/// kernel copies raw 64-bit words, so `T` must be exactly eight bytes wide;
/// this is checked in debug builds only.
pub fn masked_scatter_64<T: DeviceRepr + ValidAsZeroBits>(
    grad: &CudaSlice<T>,
    mask: &CudaSlice<u8>,
    out_numel: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<T>> {
    use std::mem::size_of;

    const KERNEL: &str = "masked_scatter_64";
    debug_assert_eq!(
        size_of::<T>(),
        8,
        "masked_scatter_64 requires an 8-byte element"
    );
    let source = scatter_ptx(KERNEL, 3, "b64", ".reg .b64 %val;");
    launch_scatter(grad, mask, out_numel, d, source, KERNEL.to_owned())
}
/// Scatters 16-bit `grad` words into a zero-initialised `out_numel`-element
/// buffer.
///
/// Walking the mask in order, each position with a non-zero mask byte
/// receives the next element of `grad`; all other positions stay zero. Words
/// are copied without interpretation.
pub fn masked_scatter_16(
    grad: &CudaSlice<u16>,
    mask: &CudaSlice<u8>,
    out_numel: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u16>> {
    const KERNEL: &str = "masked_scatter_16";
    let source = scatter_ptx(KERNEL, 1, "b16", ".reg .b16 %val;");
    launch_scatter(grad, mask, out_numel, d, source, KERNEL.to_owned())
}
/// Forward masked scatter for any 4-byte element type.
///
/// Returns a new `n`-element buffer: walking the mask in order, each position
/// with a non-zero mask byte receives the next element of `source`, and every
/// other position receives the matching element of `input`. The kernel copies
/// raw 32-bit words, so `T` must be exactly four bytes wide; this is checked
/// in debug builds only.
pub fn masked_scatter_forward_32<T: DeviceRepr + ValidAsZeroBits>(
    input: &CudaSlice<T>,
    source: &CudaSlice<T>,
    mask: &CudaSlice<u8>,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<T>> {
    use std::mem::size_of;

    const KERNEL: &str = "masked_scatter_forward_32";
    debug_assert_eq!(
        size_of::<T>(),
        4,
        "masked_scatter_forward_32 requires a 4-byte element"
    );
    let kernel_src = scatter_forward_ptx(KERNEL, 2, "b32", ".reg .b32 %val;");
    launch_scatter_forward(input, source, mask, n, d, kernel_src, KERNEL.to_owned())
}
/// Forward masked scatter for any 8-byte element type.
///
/// Returns a new `n`-element buffer: walking the mask in order, each position
/// with a non-zero mask byte receives the next element of `source`, and every
/// other position receives the matching element of `input`. The kernel copies
/// raw 64-bit words, so `T` must be exactly eight bytes wide; this is checked
/// in debug builds only.
pub fn masked_scatter_forward_64<T: DeviceRepr + ValidAsZeroBits>(
    input: &CudaSlice<T>,
    source: &CudaSlice<T>,
    mask: &CudaSlice<u8>,
    n: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<T>> {
    use std::mem::size_of;

    const KERNEL: &str = "masked_scatter_forward_64";
    debug_assert_eq!(
        size_of::<T>(),
        8,
        "masked_scatter_forward_64 requires an 8-byte element"
    );
    let kernel_src = scatter_forward_ptx(KERNEL, 3, "b64", ".reg .b64 %val;");
    launch_scatter_forward(input, source, mask, n, d, kernel_src, KERNEL.to_owned())
}
// Device-backed tests: every test opens CUDA device 0 through `dev()` and
// panics if that fails, so the suite needs a working CUDA device to run.
#[cfg(test)]
mod tests {
use super::*;
// Opens CUDA device 0; panics with "cuda device" if it cannot be created.
fn dev() -> GpuDevice {
GpuDevice::new(0).expect("cuda device")
}
// Positions with a non-zero mask byte take the fill value; the rest keep
// their input value.
#[test]
fn masked_fill_f32_replaces_true_positions() {
let d = dev();
let input = d.stream().clone_htod(&vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
let mask = d.stream().clone_htod(&vec![0u8, 1, 0, 1]).unwrap();
let r = masked_fill_f32(&input, &mask, -9.0, 4, &d).unwrap();
assert_eq!(
d.stream().clone_dtoh(&r).unwrap(),
vec![1.0f32, -9.0, 3.0, -9.0]
);
}
// `x` is chosen where the condition byte is non-zero, `y` where it is zero.
#[test]
fn where_32_selects() {
let d = dev();
let cond = d.stream().clone_htod(&vec![1u8, 0, 1, 0]).unwrap();
let x = d
.stream()
.clone_htod(&vec![10.0f32, 20.0, 30.0, 40.0])
.unwrap();
let y = d
.stream()
.clone_htod(&vec![-1.0f32, -2.0, -3.0, -4.0])
.unwrap();
let r = where_32::<f32>(&cond, &x, &y, 4, &d).unwrap();
assert_eq!(
d.stream().clone_dtoh(&r).unwrap(),
vec![10.0f32, -2.0, 30.0, -4.0]
);
}
// Selected elements are packed to the front in their original order and
// the reported length equals the number of non-zero mask bytes.
#[test]
fn masked_select_32_compacts() {
let d = dev();
let input = d
.stream()
.clone_htod(&vec![1.0f32, 2.0, 3.0, 4.0, 5.0])
.unwrap();
let mask = d.stream().clone_htod(&vec![1u8, 0, 1, 1, 0]).unwrap();
let (out, len) = masked_select_32::<f32>(&input, &mask, &d).unwrap();
assert_eq!(len, 3);
assert_eq!(d.stream().clone_dtoh(&out).unwrap(), vec![1.0f32, 3.0, 4.0]);
}
// Scatter places packed values back at the masked positions and leaves
// zeros at the unmasked ones.
#[test]
fn masked_scatter_32_is_inverse_of_compact() {
let d = dev();
let mask = d.stream().clone_htod(&vec![1u8, 0, 1, 1, 0]).unwrap();
let grad = d.stream().clone_htod(&vec![10.0f32, 30.0, 40.0]).unwrap();
let out = masked_scatter_32::<f32>(&grad, &mask, 5, &d).unwrap();
assert_eq!(
d.stream().clone_dtoh(&out).unwrap(),
vec![10.0f32, 0.0, 30.0, 40.0, 0.0]
);
}
// Masked positions consume `source` in order; unmasked positions keep the
// corresponding `input` value.
#[test]
fn masked_scatter_forward_32_keeps_input_where_false() {
let d = dev();
let input = d.stream().clone_htod(&vec![1.0f32, 2.0, 3.0, 4.0]).unwrap();
let mask = d.stream().clone_htod(&vec![0u8, 1, 1, 0]).unwrap();
let source = d.stream().clone_htod(&vec![-1.0f32, -2.0]).unwrap();
let out = masked_scatter_forward_32::<f32>(&input, &source, &mask, 4, &d).unwrap();
assert_eq!(
d.stream().clone_dtoh(&out).unwrap(),
vec![1.0f32, -1.0, -2.0, 4.0]
);
}
// Boundary cases: an all-zero mask returns the input unchanged (the
// source is never read), an all-one mask returns the source.
#[test]
fn masked_scatter_forward_64_all_false_and_all_true() {
let d = dev();
let input = d.stream().clone_htod(&vec![1.0f64, 2.0, 3.0]).unwrap();
let mask_f = d.stream().clone_htod(&vec![0u8, 0, 0]).unwrap();
let src_f = d.stream().clone_htod(&vec![9.0f64]).unwrap();
let out_f = masked_scatter_forward_64::<f64>(&input, &src_f, &mask_f, 3, &d).unwrap();
assert_eq!(
d.stream().clone_dtoh(&out_f).unwrap(),
vec![1.0f64, 2.0, 3.0]
);
let mask_t = d.stream().clone_htod(&vec![1u8, 1, 1]).unwrap();
let src_t = d.stream().clone_htod(&vec![-7.0f64, -8.0, -9.0]).unwrap();
let out_t = masked_scatter_forward_64::<f64>(&input, &src_t, &mask_t, 3, &d).unwrap();
assert_eq!(
d.stream().clone_dtoh(&out_t).unwrap(),
vec![-7.0f64, -8.0, -9.0]
);
}
// The 16-bit scatter moves raw words, so bf16 bit patterns must come back
// unchanged.
#[test]
fn masked_scatter_16_bf16_bits_roundtrip() {
let d = dev();
let one = half::bf16::from_f32(1.0).to_bits();
let two = half::bf16::from_f32(2.0).to_bits();
let mask = d.stream().clone_htod(&vec![0u8, 1, 1]).unwrap();
let grad = d.stream().clone_htod(&vec![one, two]).unwrap();
let out = masked_scatter_16(&grad, &mask, 3, &d).unwrap();
assert_eq!(d.stream().clone_dtoh(&out).unwrap(), vec![0u16, one, two]);
}
// Counts non-zero bytes, and returns 0 for an empty mask.
#[test]
fn count_true_counts() {
let d = dev();
let mask = d.stream().clone_htod(&vec![1u8, 0, 1, 1, 0, 1]).unwrap();
assert_eq!(count_true(&mask, &d).unwrap(), 4);
let empty: Vec<u8> = vec![];
let m0 = d.stream().clone_htod(&empty).unwrap();
assert_eq!(count_true(&m0, &d).unwrap(), 0);
}
// NaN and both infinities are flagged 0; ordinary values are flagged 1.
#[test]
fn isfinite_mask_f32_matches_ieee() {
let d = dev();
let input = d
.stream()
.clone_htod(&vec![
1.0f32,
f32::NAN,
3.0,
f32::INFINITY,
f32::NEG_INFINITY,
-2.5,
])
.unwrap();
let mask = isfinite_mask_f32(&input, &d).unwrap();
assert_eq!(
d.stream().clone_dtoh(&mask).unwrap(),
vec![1u8, 0, 1, 0, 0, 1]
);
}
// The device result must agree with host `f64::is_finite` element by
// element.
#[test]
fn isfinite_mask_f64_matches_ieee() {
let d = dev();
let host = vec![1.0f64, f64::NAN, f64::INFINITY, 0.0, f64::NEG_INFINITY];
let input = d.stream().clone_htod(&host).unwrap();
let mask = isfinite_mask_f64(&input, &d).unwrap();
let expected: Vec<u8> = host.iter().map(|v| u8::from(v.is_finite())).collect();
assert_eq!(d.stream().clone_dtoh(&mask).unwrap(), expected);
}
// Elements equal to the scalar are flagged 0, all others 1.
#[test]
fn ne_scalar_mask_f32_marks_unequal() {
let d = dev();
let input = d.stream().clone_htod(&vec![1.0f32, 5.0, 5.0, 2.0]).unwrap();
let mask = ne_scalar_mask_f32(&input, 5.0, &d).unwrap();
assert_eq!(d.stream().clone_dtoh(&mask).unwrap(), vec![1u8, 0, 0, 1]);
}
// A NaN element must compare as "not equal" to the scalar.
#[test]
fn ne_scalar_mask_f64_nan_is_unequal() {
let d = dev();
let input = d.stream().clone_htod(&vec![5.0f64, f64::NAN, 5.0]).unwrap();
let mask = ne_scalar_mask_f64(&input, 5.0, &d).unwrap();
assert_eq!(d.stream().clone_dtoh(&mask).unwrap(), vec![0u8, 1, 0]);
}
}