#![cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits};
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::module_cache::get_or_compile;
/// Threads per block used by every elementwise launch in this module.
const BLOCK_SIZE: u32 = 256;

/// Builds a 1-D launch configuration with enough `BLOCK_SIZE`-thread blocks to
/// cover `n` elements, always launching at least one block.
///
/// The block count is computed in `usize` so it is exact for every `n`:
/// narrowing `n` to `u32` first would truncate large inputs, and rounding up
/// with a saturating add would drop the last partial block when `n` is close
/// to `u32::MAX`.
fn launch_1d(n: usize) -> LaunchConfig {
    let blocks = n.div_ceil(BLOCK_SIZE as usize);
    // The kernels index with 32-bit registers, so in practice `blocks` always
    // fits; saturate instead of wrapping so an oversized request fails at
    // launch rather than silently covering too few elements.
    let grid = u32::try_from(blocks).unwrap_or(u32::MAX);
    LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// PTX for `add_i32_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] + b[i]`
/// over `n` 32-bit signed integers, one element per thread.
///
/// The global index is `ctaid.x * ntid.x + tid.x`. Threads whose index is
/// `>= n` branch straight to `ret` without touching memory. The byte offset is
/// the index shifted left by 2 (4-byte elements). `add.s32` carries no `.sat`
/// modifier, so overflow wraps.
const ADD_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry add_i32_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .s32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s32 %va, [%a];
ld.global.s32 %vb, [%b];
add.s32 %vr, %va, %vb;
st.global.s32 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `sub_i32_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] - b[i]`
/// over `n` 32-bit signed integers, one element per thread.
///
/// Threads with global index `>= n` exit immediately. Elements are 4 bytes
/// (offset = index << 2). `sub.s32` has no `.sat`, so overflow wraps.
const SUB_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sub_i32_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .s32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s32 %va, [%a];
ld.global.s32 %vb, [%b];
sub.s32 %vr, %va, %vb;
st.global.s32 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `mul_i32_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] * b[i]`
/// over `n` 32-bit signed integers, one element per thread.
///
/// `mul.lo.s32` keeps the low 32 bits of the product, so overflow wraps.
/// Threads with global index `>= n` exit immediately.
const MUL_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mul_i32_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .s32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s32 %va, [%a];
ld.global.s32 %vb, [%b];
mul.lo.s32 %vr, %va, %vb;
st.global.s32 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `floordiv_i32_kernel(a_ptr, b_ptr, out_ptr, n)`: elementwise floor
/// division (rounding toward negative infinity) of 32-bit signed integers.
///
/// `div.s32`/`rem.s32` truncate toward zero. The quotient is therefore
/// decremented by one whenever the remainder is non-zero and its sign differs
/// from the divisor's. For example, `-7 / 2` truncates to `-3` with remainder
/// `-1`, and is corrected to `-4`.
///
/// NOTE(review): neither this kernel nor the host wrapper guards `b[i] == 0`
/// or `i32::MIN / -1`. The PTX ISA describes integer division by zero as
/// yielding a machine-specific value. Confirm what callers expect.
const FLOORDIV_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry floordiv_i32_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .s32 %va, %vb, %q, %r, %qm1, %zero;
.reg .pred %p, %rnz, %rneg, %bneg, %diff;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s32 %va, [%a];
ld.global.s32 %vb, [%b];
mov.s32 %zero, 0;
div.s32 %q, %va, %vb;
rem.s32 %r, %va, %vb;
// diff = (r < 0) XOR (b < 0); correction when r != 0 && diff.
setp.ne.s32 %rnz, %r, %zero;
setp.lt.s32 %rneg, %r, %zero;
setp.lt.s32 %bneg, %vb, %zero;
xor.pred %diff, %rneg, %bneg;
and.pred %diff, %diff, %rnz;
sub.s32 %qm1, %q, 1;
selp.s32 %q, %qm1, %q, %diff;
st.global.s32 [%out], %q;
DONE:
ret;
}
";
/// PTX for `remainder_i32_kernel(a_ptr, b_ptr, out_ptr, n)`: elementwise
/// floored modulo of 32-bit signed integers, so the result takes the sign of
/// the divisor (like Python's `%`).
///
/// `rem.s32` follows the dividend's sign, so `b[i]` is added back whenever the
/// remainder is non-zero and its sign differs from the divisor's. For example,
/// `-7 rem 2` is `-1`, which is corrected to `1`.
///
/// NOTE(review): `b[i] == 0` is not guarded here or by the host wrapper.
const REMAINDER_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry remainder_i32_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .s32 %va, %vb, %r, %rpb, %zero;
.reg .pred %p, %rnz, %rneg, %bneg, %diff;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s32 %va, [%a];
ld.global.s32 %vb, [%b];
mov.s32 %zero, 0;
rem.s32 %r, %va, %vb;
setp.ne.s32 %rnz, %r, %zero;
setp.lt.s32 %rneg, %r, %zero;
setp.lt.s32 %bneg, %vb, %zero;
xor.pred %diff, %rneg, %bneg;
and.pred %diff, %diff, %rnz;
add.s32 %rpb, %r, %vb;
selp.s32 %r, %rpb, %r, %diff;
st.global.s32 [%out], %r;
DONE:
ret;
}
";
/// PTX for `bitand_i32_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] & b[i]`
/// over `n` 32-bit values, one element per thread.
///
/// The value registers are untyped `.b32` because only the bit pattern
/// matters. Threads with global index `>= n` exit immediately.
const BITAND_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry bitand_i32_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .b32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b32 %va, [%a];
ld.global.b32 %vb, [%b];
and.b32 %vr, %va, %vb;
st.global.b32 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `bitor_i32_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] | b[i]`
/// over `n` 32-bit values, one element per thread.
///
/// Threads with global index `>= n` exit immediately.
const BITOR_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry bitor_i32_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .b32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b32 %va, [%a];
ld.global.b32 %vb, [%b];
or.b32 %vr, %va, %vb;
st.global.b32 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `bitxor_i32_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] ^ b[i]`
/// over `n` 32-bit values, one element per thread.
///
/// Threads with global index `>= n` exit immediately.
const BITXOR_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry bitxor_i32_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .b32 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b32 %va, [%a];
ld.global.b32 %vb, [%b];
xor.b32 %vr, %va, %vb;
st.global.b32 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `shl_i32_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] << b[i]`
/// over `n` 32-bit values, one element per thread.
///
/// The shift amount is loaded from `b` as an unsigned 32-bit value, so a
/// negative `b[i]` acts as a very large shift. Per the PTX ISA, `shl` clamps
/// shift amounts greater than the register width. Any amount `>= 32`,
/// including a negative one, therefore produces 0.
const SHL_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry shl_i32_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr, %sh;
.reg .u64 %a, %b, %out, %off;
.reg .b32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b32 %va, [%a];
ld.global.u32 %sh, [%b];
shl.b32 %vr, %va, %sh;
st.global.b32 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `shr_i32_kernel(a_ptr, b_ptr, out_ptr, n)`: arithmetic
/// (sign-filling) right shift `out[i] = a[i] >> b[i]` on 32-bit signed values.
///
/// The shift amount is loaded from `b` as an unsigned 32-bit value. Per the
/// PTX ISA, amounts greater than the register width are clamped. Any amount
/// `>= 32`, including a negative `b[i]`, therefore yields 0 or -1 depending
/// on the sign of `a[i]`.
const SHR_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry shr_i32_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr, %sh;
.reg .u64 %a, %b, %out, %off;
.reg .s32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s32 %va, [%a];
ld.global.u32 %sh, [%b];
shr.s32 %vr, %va, %sh;
st.global.s32 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `neg_i32_kernel(a_ptr, out_ptr, n)`: `out[i] = -a[i]` over `n`
/// 32-bit signed integers, one element per thread.
///
/// Negation is computed as `0 - a[i]` with a wrapping `sub.s32`, so
/// `i32::MIN` maps to itself. Threads with global index `>= n` exit
/// immediately.
const NEG_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry neg_i32_kernel(
.param .u64 a_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %out, %off;
.reg .s32 %va, %vr, %zero;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.s32 %va, [%a];
mov.s32 %zero, 0;
sub.s32 %vr, %zero, %va;
st.global.s32 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `bitnot_i32_kernel(a_ptr, out_ptr, n)`: bitwise complement
/// `out[i] = !a[i]` over `n` 32-bit values, one element per thread.
///
/// Threads with global index `>= n` exit immediately.
const BITNOT_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry bitnot_i32_kernel(
.param .u64 a_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %out, %off;
.reg .b32 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 2;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.b32 %va, [%a];
not.b32 %vr, %va;
st.global.b32 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `reduce_i32_kernel(a_ptr, out_ptr, n, op)`: reduces `n` 32-bit
/// signed integers into the single value at `out_ptr`.
///
/// Only the thread with global index 0 does any work. It seeds the accumulator
/// with `a[0]` and walks `a[1..n]` sequentially. `op` selects the operation:
/// 0 = sum, 1 = product (low 32 bits), 2 = min. Every other value falls
/// through to max. Sum and product wrap on overflow.
///
/// The caller must guarantee `n >= 1`, because `a[0]` is read unconditionally.
/// The `%mn`/`%mx` registers are declared but never used.
const REDUCE_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry reduce_i32_kernel(
.param .u64 a_ptr, .param .u64 out_ptr, .param .u32 n, .param .u32 op
) {
.reg .u32 %idx, %bid, %bdim, %nr, %op_r, %i;
.reg .u64 %a, %out, %off, %cur;
.reg .s32 %acc, %v, %prod, %mn, %mx;
.reg .pred %p, %only0, %is_sum, %is_prod, %is_min, %lt, %gt;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
ld.param.u32 %op_r, [op];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
// Only thread 0 performs the reduction.
setp.ne.u32 %only0, %idx, 0;
@%only0 bra DONE;
setp.eq.u32 %is_sum, %op_r, 0;
setp.eq.u32 %is_prod, %op_r, 1;
setp.eq.u32 %is_min, %op_r, 2;
// Initialise accumulator from a[0] (n >= 1 guaranteed by the host).
ld.global.s32 %acc, [%a];
mov.u32 %i, 1;
LOOP:
setp.ge.u32 %p, %i, %nr;
@%p bra STORE;
cvt.u64.u32 %off, %i;
shl.b64 %off, %off, 2;
add.u64 %cur, %a, %off;
ld.global.s32 %v, [%cur];
// sum
@%is_sum add.s32 %acc, %acc, %v;
// prod
mul.lo.s32 %prod, %acc, %v;
@%is_prod mov.s32 %acc, %prod;
// min
setp.lt.s32 %lt, %v, %acc;
@%is_min selp.s32 %acc, %v, %acc, %lt;
// max (the remaining op): only update when op==3.
setp.gt.s32 %gt, %v, %acc;
@%is_sum bra SKIPMAX;
@%is_prod bra SKIPMAX;
@%is_min bra SKIPMAX;
selp.s32 %acc, %v, %acc, %gt;
SKIPMAX:
add.u32 %i, %i, 1;
bra LOOP;
STORE:
st.global.s32 [%out], %acc;
DONE:
ret;
}
";
/// PTX for `add_i64_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] + b[i]`
/// over `n` 64-bit signed integers, one element per thread.
///
/// Elements are 8 bytes (offset = index << 3). Threads with global index
/// `>= n` exit immediately. `add.s64` has no `.sat`, so overflow wraps.
const ADD_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry add_i64_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .s64 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s64 %va, [%a];
ld.global.s64 %vb, [%b];
add.s64 %vr, %va, %vb;
st.global.s64 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `sub_i64_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] - b[i]`
/// over `n` 64-bit signed integers, one element per thread.
///
/// Elements are 8 bytes wide. `sub.s64` has no `.sat`, so overflow wraps.
const SUB_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry sub_i64_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .s64 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s64 %va, [%a];
ld.global.s64 %vb, [%b];
sub.s64 %vr, %va, %vb;
st.global.s64 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `mul_i64_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] * b[i]`
/// over `n` 64-bit signed integers, one element per thread.
///
/// `mul.lo.s64` keeps the low 64 bits of the product, so overflow wraps.
const MUL_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry mul_i64_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .s64 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s64 %va, [%a];
ld.global.s64 %vb, [%b];
mul.lo.s64 %vr, %va, %vb;
st.global.s64 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `floordiv_i64_kernel(a_ptr, b_ptr, out_ptr, n)`: elementwise floor
/// division (rounding toward negative infinity) of 64-bit signed integers.
///
/// `div.s64`/`rem.s64` truncate toward zero. The quotient is decremented by
/// one whenever the remainder is non-zero and its sign differs from the
/// divisor's.
///
/// NOTE(review): `b[i] == 0` and `i64::MIN / -1` are not guarded here or by
/// the host wrapper. The PTX ISA describes integer division by zero as
/// machine-specific.
const FLOORDIV_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry floordiv_i64_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .s64 %va, %vb, %q, %r, %qm1, %zero;
.reg .pred %p, %rnz, %rneg, %bneg, %diff;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s64 %va, [%a];
ld.global.s64 %vb, [%b];
mov.s64 %zero, 0;
div.s64 %q, %va, %vb;
rem.s64 %r, %va, %vb;
setp.ne.s64 %rnz, %r, %zero;
setp.lt.s64 %rneg, %r, %zero;
setp.lt.s64 %bneg, %vb, %zero;
xor.pred %diff, %rneg, %bneg;
and.pred %diff, %diff, %rnz;
sub.s64 %qm1, %q, 1;
selp.s64 %q, %qm1, %q, %diff;
st.global.s64 [%out], %q;
DONE:
ret;
}
";
/// PTX for `remainder_i64_kernel(a_ptr, b_ptr, out_ptr, n)`: elementwise
/// floored modulo of 64-bit signed integers, so the result takes the sign of
/// the divisor.
///
/// `rem.s64` follows the dividend's sign, so `b[i]` is added back whenever the
/// remainder is non-zero and its sign differs from the divisor's.
///
/// NOTE(review): `b[i] == 0` is not guarded here or by the host wrapper.
const REMAINDER_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry remainder_i64_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .s64 %va, %vb, %r, %rpb, %zero;
.reg .pred %p, %rnz, %rneg, %bneg, %diff;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s64 %va, [%a];
ld.global.s64 %vb, [%b];
mov.s64 %zero, 0;
rem.s64 %r, %va, %vb;
setp.ne.s64 %rnz, %r, %zero;
setp.lt.s64 %rneg, %r, %zero;
setp.lt.s64 %bneg, %vb, %zero;
xor.pred %diff, %rneg, %bneg;
and.pred %diff, %diff, %rnz;
add.s64 %rpb, %r, %vb;
selp.s64 %r, %rpb, %r, %diff;
st.global.s64 [%out], %r;
DONE:
ret;
}
";
/// PTX for `bitand_i64_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] & b[i]`
/// over `n` 64-bit values, one element per thread.
const BITAND_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry bitand_i64_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .b64 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b64 %va, [%a];
ld.global.b64 %vb, [%b];
and.b64 %vr, %va, %vb;
st.global.b64 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `bitor_i64_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] | b[i]`
/// over `n` 64-bit values, one element per thread.
const BITOR_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry bitor_i64_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .b64 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b64 %va, [%a];
ld.global.b64 %vb, [%b];
or.b64 %vr, %va, %vb;
st.global.b64 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `bitxor_i64_kernel(a_ptr, b_ptr, out_ptr, n)`: `out[i] = a[i] ^ b[i]`
/// over `n` 64-bit values, one element per thread.
const BITXOR_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry bitxor_i64_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .b64 %va, %vb, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b64 %va, [%a];
ld.global.b64 %vb, [%b];
xor.b64 %vr, %va, %vb;
st.global.b64 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `shl_i64_kernel(a_ptr, b_ptr, out_ptr, n)`: elementwise
/// `out[i] = a[i] << b[i]` on 64-bit integers, one element per thread.
///
/// The shift amount is read from `b` as an unsigned 64-bit value. It is
/// clamped to 64 *before* being narrowed to the 32-bit register that `shl`
/// takes. Narrowing first would wrap large amounts: a shift of 2^32 would
/// become 0 and return `a[i]` unchanged. With the clamp, every amount `>= 64`
/// yields 0, including a negative `b[i]` (which is huge when read unsigned).
/// This matches the `i32` kernel, where the hardware clamp already sees the
/// full amount.
const SHL_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry shl_i64_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr, %sh;
.reg .u64 %a, %b, %out, %off, %vb, %shc;
.reg .b64 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.b64 %va, [%a];
ld.global.u64 %vb, [%b];
// Clamp the full 64-bit amount before narrowing so shifts >= 2^32 still saturate.
min.u64 %shc, %vb, 64;
cvt.u32.u64 %sh, %shc;
shl.b64 %vr, %va, %sh;
st.global.b64 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `shr_i64_kernel(a_ptr, b_ptr, out_ptr, n)`: elementwise arithmetic
/// (sign-filling) right shift `out[i] = a[i] >> b[i]` on 64-bit signed integers.
///
/// The shift amount is read from `b` as an unsigned 64-bit value and clamped
/// to 64 before being narrowed to the 32-bit register that `shr` takes. Every
/// amount `>= 64`, including a negative `b[i]`, therefore yields 0 for
/// non-negative `a[i]` and -1 for negative `a[i]`. Narrowing first would
/// instead wrap amounts such as 2^32 back to 0.
const SHR_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry shr_i64_kernel(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr, %sh;
.reg .u64 %a, %b, %out, %off, %vb, %shc;
.reg .s64 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.s64 %va, [%a];
ld.global.u64 %vb, [%b];
// Clamp the full 64-bit amount before narrowing so shifts >= 2^32 still saturate.
min.u64 %shc, %vb, 64;
cvt.u32.u64 %sh, %shc;
shr.s64 %vr, %va, %sh;
st.global.s64 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `neg_i64_kernel(a_ptr, out_ptr, n)`: `out[i] = -a[i]` over `n`
/// 64-bit signed integers, one element per thread.
///
/// Negation is computed as `0 - a[i]` with a wrapping `sub.s64`, so
/// `i64::MIN` maps to itself.
const NEG_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry neg_i64_kernel(
.param .u64 a_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %out, %off;
.reg .s64 %va, %vr, %zero;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.s64 %va, [%a];
mov.s64 %zero, 0;
sub.s64 %vr, %zero, %va;
st.global.s64 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `bitnot_i64_kernel(a_ptr, out_ptr, n)`: bitwise complement
/// `out[i] = !a[i]` over `n` 64-bit values, one element per thread.
const BITNOT_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry bitnot_i64_kernel(
.param .u64 a_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %out, %off;
.reg .b64 %va, %vr;
.reg .pred %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
shl.b64 %off, %off, 3;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.b64 %va, [%a];
not.b64 %vr, %va;
st.global.b64 [%out], %vr;
DONE:
ret;
}
";
/// PTX for `reduce_i64_kernel(a_ptr, out_ptr, n, op)`: reduces `n` 64-bit
/// signed integers into the single value at `out_ptr`.
///
/// Only the thread with global index 0 does any work. It seeds the accumulator
/// with `a[0]` and walks `a[1..n]` sequentially. `op` selects the operation:
/// 0 = sum, 1 = product (low 64 bits), 2 = min. Every other value falls
/// through to max. The caller must guarantee `n >= 1`, because `a[0]` is read
/// unconditionally.
const REDUCE_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry reduce_i64_kernel(
.param .u64 a_ptr, .param .u64 out_ptr, .param .u32 n, .param .u32 op
) {
.reg .u32 %idx, %bid, %bdim, %nr, %op_r, %i;
.reg .u64 %a, %out, %off, %cur;
.reg .s64 %acc, %v, %prod;
.reg .pred %p, %only0, %is_sum, %is_prod, %is_min, %lt, %gt;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
ld.param.u32 %op_r, [op];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ne.u32 %only0, %idx, 0;
@%only0 bra DONE;
setp.eq.u32 %is_sum, %op_r, 0;
setp.eq.u32 %is_prod, %op_r, 1;
setp.eq.u32 %is_min, %op_r, 2;
ld.global.s64 %acc, [%a];
mov.u32 %i, 1;
LOOP:
setp.ge.u32 %p, %i, %nr;
@%p bra STORE;
cvt.u64.u32 %off, %i;
shl.b64 %off, %off, 3;
add.u64 %cur, %a, %off;
ld.global.s64 %v, [%cur];
@%is_sum add.s64 %acc, %acc, %v;
mul.lo.s64 %prod, %acc, %v;
@%is_prod mov.s64 %acc, %prod;
setp.lt.s64 %lt, %v, %acc;
@%is_min selp.s64 %acc, %v, %acc, %lt;
setp.gt.s64 %gt, %v, %acc;
@%is_sum bra SKIPMAX;
@%is_prod bra SKIPMAX;
@%is_min bra SKIPMAX;
selp.s64 %acc, %v, %acc, %gt;
SKIPMAX:
add.u32 %i, %i, 1;
bra LOOP;
STORE:
st.global.s64 [%out], %acc;
DONE:
ret;
}
";
/// `op` selector for the `reduce_*_kernel` entries: wrapping sum.
const REDUCE_SUM: u32 = 0;
/// `op` selector for the `reduce_*_kernel` entries: wrapping product.
const REDUCE_PROD: u32 = REDUCE_SUM + 1;
/// `op` selector for the `reduce_*_kernel` entries: smallest element.
const REDUCE_MIN: u32 = REDUCE_PROD + 1;
/// `op` selector for the `reduce_*_kernel` entries: largest element.
///
/// The kernels only test `op` against 0, 1 and 2 and treat anything else as
/// max, so this is the conventional value rather than the only one that works.
const REDUCE_MAX: u32 = REDUCE_MIN + 1;
/// Launches an elementwise binary kernel `out[i] = op(a[i], b[i])` over two
/// equally long device buffers and returns a freshly allocated output buffer.
///
/// `ptx` must define an entry named `kernel_name` whose parameters are
/// `(a_ptr, b_ptr, out_ptr, n: u32)` in that order, matching the arguments
/// pushed below. `T` must have the element width the kernel expects.
///
/// # Errors
/// - [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
/// - [`GpuError::PtxCompileFailed`] if the module cannot be compiled or loaded.
/// - Any driver error from allocation or launch, converted via `?`.
///
/// # Panics
/// Panics if the inputs hold more than `u32::MAX` elements. The kernels take a
/// 32-bit length and 32-bit indices, so a truncated length would silently
/// leave the tail of the output zeroed.
fn launch_binary<T: DeviceRepr + ValidAsZeroBits>(
    a: &CudaSlice<T>,
    b: &CudaSlice<T>,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<CudaSlice<T>> {
    if a.len() != b.len() {
        return Err(GpuError::LengthMismatch {
            a: a.len(),
            b: b.len(),
        });
    }
    let n = a.len();
    let stream = device.stream();
    if n == 0 {
        return Ok(stream.alloc_zeros::<T>(0)?);
    }
    // Check before compiling or allocating so oversized inputs fail fast.
    let n_u32 = u32::try_from(n)
        .expect("GPU integer kernels use 32-bit lengths; input exceeds u32::MAX elements");
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out = stream.alloc_zeros::<T>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the arguments (three device buffers, then a u32 length) match the
    // kernel signature documented above. `a`, `b` and `out` all hold exactly
    // `n` elements, and the kernels skip every thread whose index is `>= n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a)
            .arg(b)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Launches an elementwise unary kernel `out[i] = op(a[i])` and returns a
/// freshly allocated output buffer, which is empty when `a` is empty.
///
/// `ptx` must define an entry named `kernel_name` whose parameters are
/// `(a_ptr, out_ptr, n: u32)` in that order, matching the arguments pushed
/// below.
///
/// # Errors
/// - [`GpuError::PtxCompileFailed`] if the module cannot be compiled or loaded.
/// - Any driver error from allocation or launch, converted via `?`.
///
/// # Panics
/// Panics if `a` holds more than `u32::MAX` elements. The kernels take a 32-bit
/// length, so truncating it would silently leave part of the output zeroed.
fn launch_unary<T: DeviceRepr + ValidAsZeroBits>(
    a: &CudaSlice<T>,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<CudaSlice<T>> {
    let n = a.len();
    let stream = device.stream();
    if n == 0 {
        return Ok(stream.alloc_zeros::<T>(0)?);
    }
    // Check before compiling or allocating so oversized inputs fail fast.
    let n_u32 = u32::try_from(n)
        .expect("GPU integer kernels use 32-bit lengths; input exceeds u32::MAX elements");
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out = stream.alloc_zeros::<T>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the arguments (two device buffers, then a u32 length) match the
    // kernel signature documented above. `a` and `out` both hold exactly `n`
    // elements, and the kernels skip every thread whose index is `>= n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Runs a reduction kernel over `a` and returns a one-element device buffer
/// holding the result.
///
/// `ptx` must define an entry named `kernel_name` whose parameters are
/// `(a_ptr, out_ptr, n: u32, op: u32)`. The kernel is launched as a single
/// block with a single thread. For an empty input no kernel runs;
/// `empty_identity` is uploaded instead, because the kernels read `a[0]`
/// unconditionally.
///
/// # Errors
/// - [`GpuError::PtxCompileFailed`] if the module cannot be compiled or loaded.
/// - Any driver error from upload, allocation or launch, converted via `?`.
///
/// # Panics
/// Panics if `a` holds more than `u32::MAX` elements. A truncated length would
/// silently reduce only a prefix of the input.
fn launch_reduce<T: DeviceRepr + ValidAsZeroBits>(
    a: &CudaSlice<T>,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
    op: u32,
    empty_identity: T,
) -> GpuResult<CudaSlice<T>> {
    let n = a.len();
    let stream = device.stream();
    if n == 0 {
        let host = [empty_identity];
        return Ok(stream.clone_htod(&host)?);
    }
    // Check before compiling or allocating so oversized inputs fail fast.
    let n_u32 = u32::try_from(n)
        .expect("GPU integer kernels use 32-bit lengths; input exceeds u32::MAX elements");
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out = stream.alloc_zeros::<T>(1)?;
    let cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (1, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the arguments (input buffer, one-element output buffer, u32 length,
    // u32 selector) match the kernel signature documented above. `n >= 1` here,
    // so the unconditional read of `a[0]` is in bounds.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a)
            .arg(&mut out)
            .arg(&n_u32)
            .arg(&op)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Elementwise wrapping addition `a[i] + b[i]` of two `i32` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_add_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "add_i32_kernel";
    launch_binary::<i32>(a, b, d, ADD_I32_PTX, KERNEL)
}
/// Elementwise wrapping subtraction `a[i] - b[i]` of two `i32` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_sub_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "sub_i32_kernel";
    launch_binary::<i32>(a, b, d, SUB_I32_PTX, KERNEL)
}
/// Elementwise multiplication of two `i32` device buffers, keeping the low
/// 32 bits of each product.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_mul_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "mul_i32_kernel";
    launch_binary::<i32>(a, b, d, MUL_I32_PTX, KERNEL)
}
/// Elementwise floor division `a[i] / b[i]` of two `i32` device buffers,
/// rounding toward negative infinity.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
/// NOTE(review): `b[i] == 0` is not rejected; the output for such elements is
/// whatever the device's `div.s32` produces.
pub fn gpu_floor_div_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "floordiv_i32_kernel";
    launch_binary::<i32>(a, b, d, FLOORDIV_I32_PTX, KERNEL)
}
/// Elementwise floored remainder of two `i32` device buffers. Each result
/// takes the sign of the divisor `b[i]`.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_remainder_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "remainder_i32_kernel";
    launch_binary::<i32>(a, b, d, REMAINDER_I32_PTX, KERNEL)
}
/// Elementwise bitwise AND `a[i] & b[i]` of two `i32` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_bitand_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "bitand_i32_kernel";
    launch_binary::<i32>(a, b, d, BITAND_I32_PTX, KERNEL)
}
/// Elementwise bitwise OR `a[i] | b[i]` of two `i32` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_bitor_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "bitor_i32_kernel";
    launch_binary::<i32>(a, b, d, BITOR_I32_PTX, KERNEL)
}
/// Elementwise bitwise XOR `a[i] ^ b[i]` of two `i32` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_bitxor_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "bitxor_i32_kernel";
    launch_binary::<i32>(a, b, d, BITXOR_I32_PTX, KERNEL)
}
/// Elementwise left shift of `a[i]` by `b[i]` bits for `i32` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_shl_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "shl_i32_kernel";
    launch_binary::<i32>(a, b, d, SHL_I32_PTX, KERNEL)
}
/// Elementwise arithmetic (sign-preserving) right shift of `a[i]` by `b[i]`
/// bits for `i32` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_shr_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "shr_i32_kernel";
    launch_binary::<i32>(a, b, d, SHR_I32_PTX, KERNEL)
}
/// Elementwise wrapping negation `-a[i]` of an `i32` device buffer
/// (`i32::MIN` maps to itself).
pub fn gpu_neg_i32(a: &CudaSlice<i32>, d: &GpuDevice) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "neg_i32_kernel";
    launch_unary::<i32>(a, d, NEG_I32_PTX, KERNEL)
}
/// Elementwise bitwise complement `!a[i]` of an `i32` device buffer.
pub fn gpu_bitnot_i32(a: &CudaSlice<i32>, d: &GpuDevice) -> GpuResult<CudaSlice<i32>> {
    const KERNEL: &str = "bitnot_i32_kernel";
    launch_unary::<i32>(a, d, BITNOT_I32_PTX, KERNEL)
}
/// Wrapping sum of every element of an `i32` device buffer, returned as a
/// one-element device buffer (`0` for an empty input).
pub fn gpu_sum_i32(a: &CudaSlice<i32>, d: &GpuDevice) -> GpuResult<CudaSlice<i32>> {
    let identity: i32 = 0;
    launch_reduce(a, d, REDUCE_I32_PTX, "reduce_i32_kernel", REDUCE_SUM, identity)
}
/// Wrapping product of every element of an `i32` device buffer, returned as a
/// one-element device buffer (`1` for an empty input).
pub fn gpu_prod_i32(a: &CudaSlice<i32>, d: &GpuDevice) -> GpuResult<CudaSlice<i32>> {
    let identity: i32 = 1;
    launch_reduce(a, d, REDUCE_I32_PTX, "reduce_i32_kernel", REDUCE_PROD, identity)
}
/// Smallest element of an `i32` device buffer, returned as a one-element
/// device buffer (`i32::MAX` for an empty input).
pub fn gpu_min_i32(a: &CudaSlice<i32>, d: &GpuDevice) -> GpuResult<CudaSlice<i32>> {
    let empty_result = i32::MAX;
    launch_reduce(a, d, REDUCE_I32_PTX, "reduce_i32_kernel", REDUCE_MIN, empty_result)
}
/// Largest element of an `i32` device buffer, returned as a one-element
/// device buffer (`i32::MIN` for an empty input).
pub fn gpu_max_i32(a: &CudaSlice<i32>, d: &GpuDevice) -> GpuResult<CudaSlice<i32>> {
    let empty_result = i32::MIN;
    launch_reduce(a, d, REDUCE_I32_PTX, "reduce_i32_kernel", REDUCE_MAX, empty_result)
}
/// Elementwise wrapping addition `a[i] + b[i]` of two `i64` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_add_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "add_i64_kernel";
    launch_binary::<i64>(a, b, d, ADD_I64_PTX, KERNEL)
}
/// Elementwise wrapping subtraction `a[i] - b[i]` of two `i64` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_sub_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "sub_i64_kernel";
    launch_binary::<i64>(a, b, d, SUB_I64_PTX, KERNEL)
}
/// Elementwise multiplication of two `i64` device buffers, keeping the low
/// 64 bits of each product.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_mul_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "mul_i64_kernel";
    launch_binary::<i64>(a, b, d, MUL_I64_PTX, KERNEL)
}
/// Elementwise floor division `a[i] / b[i]` of two `i64` device buffers,
/// rounding toward negative infinity.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
/// NOTE(review): `b[i] == 0` is not rejected; the output for such elements is
/// whatever the device's `div.s64` produces.
pub fn gpu_floor_div_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "floordiv_i64_kernel";
    launch_binary::<i64>(a, b, d, FLOORDIV_I64_PTX, KERNEL)
}
/// Elementwise floored remainder of two `i64` device buffers. Each result
/// takes the sign of the divisor `b[i]`.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_remainder_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "remainder_i64_kernel";
    launch_binary::<i64>(a, b, d, REMAINDER_I64_PTX, KERNEL)
}
/// Elementwise bitwise AND `a[i] & b[i]` of two `i64` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_bitand_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "bitand_i64_kernel";
    launch_binary::<i64>(a, b, d, BITAND_I64_PTX, KERNEL)
}
/// Elementwise bitwise OR `a[i] | b[i]` of two `i64` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_bitor_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "bitor_i64_kernel";
    launch_binary::<i64>(a, b, d, BITOR_I64_PTX, KERNEL)
}
/// Elementwise bitwise XOR `a[i] ^ b[i]` of two `i64` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_bitxor_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "bitxor_i64_kernel";
    launch_binary::<i64>(a, b, d, BITXOR_I64_PTX, KERNEL)
}
/// Elementwise left shift of `a[i]` by `b[i]` bits for `i64` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_shl_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "shl_i64_kernel";
    launch_binary::<i64>(a, b, d, SHL_I64_PTX, KERNEL)
}
/// Elementwise arithmetic (sign-preserving) right shift of `a[i]` by `b[i]`
/// bits for `i64` device buffers.
///
/// Returns [`GpuError::LengthMismatch`] if `a` and `b` differ in length.
pub fn gpu_shr_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "shr_i64_kernel";
    launch_binary::<i64>(a, b, d, SHR_I64_PTX, KERNEL)
}
/// Elementwise wrapping negation `-a[i]` of an `i64` device buffer
/// (`i64::MIN` maps to itself).
pub fn gpu_neg_i64(a: &CudaSlice<i64>, d: &GpuDevice) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "neg_i64_kernel";
    launch_unary::<i64>(a, d, NEG_I64_PTX, KERNEL)
}
/// Elementwise bitwise complement `!a[i]` of an `i64` device buffer.
pub fn gpu_bitnot_i64(a: &CudaSlice<i64>, d: &GpuDevice) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "bitnot_i64_kernel";
    launch_unary::<i64>(a, d, BITNOT_I64_PTX, KERNEL)
}
/// Wrapping sum of every element of an `i64` device buffer, returned as a
/// one-element device buffer (`0` for an empty input).
pub fn gpu_sum_i64(a: &CudaSlice<i64>, d: &GpuDevice) -> GpuResult<CudaSlice<i64>> {
    let identity: i64 = 0;
    launch_reduce(a, d, REDUCE_I64_PTX, "reduce_i64_kernel", REDUCE_SUM, identity)
}
/// Wrapping product of every element of an `i64` device buffer, returned as a
/// one-element device buffer (`1` for an empty input).
pub fn gpu_prod_i64(a: &CudaSlice<i64>, d: &GpuDevice) -> GpuResult<CudaSlice<i64>> {
    let identity: i64 = 1;
    launch_reduce(a, d, REDUCE_I64_PTX, "reduce_i64_kernel", REDUCE_PROD, identity)
}
/// Smallest element of an `i64` device buffer, returned as a one-element
/// device buffer (`i64::MAX` for an empty input).
pub fn gpu_min_i64(a: &CudaSlice<i64>, d: &GpuDevice) -> GpuResult<CudaSlice<i64>> {
    let empty_result = i64::MAX;
    launch_reduce(a, d, REDUCE_I64_PTX, "reduce_i64_kernel", REDUCE_MIN, empty_result)
}
/// Largest element of an `i64` device buffer, returned as a one-element
/// device buffer (`i64::MIN` for an empty input).
pub fn gpu_max_i64(a: &CudaSlice<i64>, d: &GpuDevice) -> GpuResult<CudaSlice<i64>> {
    let empty_result = i64::MIN;
    launch_reduce(a, d, REDUCE_I64_PTX, "reduce_i64_kernel", REDUCE_MAX, empty_result)
}