#![cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits};
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::module_cache::{get_or_compile, get_or_compile_owned};
/// Number of threads per block (x dimension) used by `launch_1d` when it
/// builds a launch configuration.
const BLOCK_SIZE: u32 = 256;
/// Builds a 1-D launch configuration with one thread per element.
///
/// The grid holds `ceil(n / BLOCK_SIZE)` blocks of `BLOCK_SIZE` threads and is
/// never smaller than one block, so `n == 0` still yields a launchable config.
///
/// The block count is computed in `usize` with an exact ceiling division.
/// The previous `(n as u32).saturating_add(BLOCK_SIZE - 1) / BLOCK_SIZE` form
/// lost one block when `n` was within `BLOCK_SIZE - 1` of `u32::MAX` (the
/// addition saturated before the division) and wrapped for `n > u32::MAX`.
/// A block count that does not fit in `u32` saturates to `u32::MAX`.
///
/// NOTE(review): the kernels in this file take the element count as a `u32`
/// parameter, so callers still need `n <= u32::MAX` for a complete result.
fn launch_1d(n: usize) -> LaunchConfig {
    let threads = BLOCK_SIZE as usize;
    // Quotient plus "is there a remainder": cannot overflow for any `n`.
    let blocks = n / threads + usize::from(n % threads != 0);
    // Saturate instead of wrapping; the cast is lossless after the `min`.
    let grid = blocks.min(u32::MAX as usize) as u32;
    LaunchConfig {
        grid_dim: (grid.max(1), 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// Renders the PTX source of an element-wise comparison kernel.
///
/// The generated entry point takes `(a_ptr, b_ptr, out_ptr, n)`. Each thread
/// whose global index is below `n` loads one element from `a` and `b`,
/// evaluates `setp`, and stores a single byte (`1` or `0`) at `out[idx]`.
///
/// * `kernel_name` - name of the `.visible .entry` symbol.
/// * `in_shift` - left shift applied to the element index to obtain the input
///   byte offset (callers in this file pass `2` for 4-byte and `3` for 8-byte
///   elements).
/// * `load_ty` - PTX type suffix used for both `ld.global` instructions.
/// * `reg_decl` - register declaration line; it must declare `%va` and `%vb`.
/// * `setp` - comparison instruction; it must write the predicate `%c`.
fn cmp_ptx(
    kernel_name: &str,
    in_shift: u32,
    load_ty: &str,
    reg_decl: &str,
    setp: &str,
) -> String {
    format!(
        concat!(
            // Module header and entry signature.
            ".version 7.0\n",
            ".target sm_52\n",
            ".address_size 64\n",
            ".visible .entry {kernel_name}(\n",
            ".param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n\n",
            ") {{\n",
            // Register declarations.
            ".reg .u32 %idx, %bid, %bdim, %nr;\n",
            ".reg .u64 %a, %b, %out, %ioff, %ooff;\n",
            "{reg_decl}\n",
            ".reg .u16 %res;\n",
            ".reg .pred %p, %c;\n",
            // Parameter loads.
            "ld.param.u64 %a, [a_ptr];\n",
            "ld.param.u64 %b, [b_ptr];\n",
            "ld.param.u64 %out, [out_ptr];\n",
            "ld.param.u32 %nr, [n];\n",
            // idx = ctaid.x * ntid.x + tid.x, then bounds check against n.
            "mov.u32 %bid, %ctaid.x;\n",
            "mov.u32 %bdim, %ntid.x;\n",
            "mov.u32 %idx, %tid.x;\n",
            "mad.lo.u32 %idx, %bid, %bdim, %idx;\n",
            "setp.ge.u32 %p, %idx, %nr;\n",
            "@%p bra DONE;\n",
            // Input addresses: base + (idx << in_shift).
            "cvt.u64.u32 %ioff, %idx;\n",
            "shl.b64 %ioff, %ioff, {in_shift};\n",
            "add.u64 %a, %a, %ioff;\n",
            "add.u64 %b, %b, %ioff;\n",
            "// output is 1 byte per element: out_off = idx\n",
            "cvt.u64.u32 %ooff, %idx;\n",
            "add.u64 %out, %out, %ooff;\n",
            // Load, compare, store the 0/1 byte.
            "ld.global.{load_ty} %va, [%a];\n",
            "ld.global.{load_ty} %vb, [%b];\n",
            "{setp}\n",
            "selp.u16 %res, 1, 0, %c;\n",
            "st.global.u8 [%out], %res;\n",
            "DONE:\n",
            "ret;\n",
            "}}\n",
        ),
        kernel_name = kernel_name,
        in_shift = in_shift,
        load_ty = load_ty,
        reg_decl = reg_decl,
        setp = setp,
    )
}
/// Renders the PTX source of an element-wise comparison kernel over 16-bit
/// inputs.
///
/// The generated entry point takes `(a_ptr, b_ptr, out_ptr, n)`. Each thread
/// whose global index is below `n` loads one 16-bit value from `a` and `b`
/// into `%ha` / `%hb`, runs `decode`, evaluates `setp`, and stores a single
/// byte (`1` or `0`) at `out[idx]`.
///
/// * `kernel_name` - name of the `.visible .entry` symbol.
/// * `target` - value of the `.target` directive.
/// * `decode` - instructions that turn `%ha` / `%hb` into the `.f32` registers
///   `%fa` / `%fb`; `%zero16` (set to 0), `%ua` and `%ub` are available as
///   scratch registers.
/// * `setp` - comparison instruction; it must write the predicate `%c`.
fn cmp_half_ptx(
    kernel_name: &str,
    target: &str,
    decode: &str,
    setp: &str,
) -> String {
    format!(
        concat!(
            // Module header and entry signature.
            ".version 7.0\n",
            ".target {target}\n",
            ".address_size 64\n",
            ".visible .entry {kernel_name}(\n",
            ".param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n\n",
            ") {{\n",
            // Register declarations.
            ".reg .u32 %idx, %bid, %bdim, %nr;\n",
            ".reg .u64 %a, %b, %out, %ioff, %ooff;\n",
            ".reg .b16 %ha, %hb, %zero16;\n",
            ".reg .b32 %ua, %ub;\n",
            ".reg .u16 %res;\n",
            ".reg .f32 %fa, %fb;\n",
            ".reg .pred %p, %c;\n",
            // Parameter loads.
            "ld.param.u64 %a, [a_ptr];\n",
            "ld.param.u64 %b, [b_ptr];\n",
            "ld.param.u64 %out, [out_ptr];\n",
            "ld.param.u32 %nr, [n];\n",
            // idx = ctaid.x * ntid.x + tid.x, then bounds check against n.
            "mov.u32 %bid, %ctaid.x;\n",
            "mov.u32 %bdim, %ntid.x;\n",
            "mov.u32 %idx, %tid.x;\n",
            "mad.lo.u32 %idx, %bid, %bdim, %idx;\n",
            "setp.ge.u32 %p, %idx, %nr;\n",
            "@%p bra DONE;\n",
            // Input addresses: base + idx * 2; output address: base + idx.
            "cvt.u64.u32 %ioff, %idx;\n",
            "shl.b64 %ioff, %ioff, 1;\n",
            "add.u64 %a, %a, %ioff;\n",
            "add.u64 %b, %b, %ioff;\n",
            "cvt.u64.u32 %ooff, %idx;\n",
            "add.u64 %out, %out, %ooff;\n",
            // Load raw 16-bit values, decode, compare, store the 0/1 byte.
            "mov.b16 %zero16, 0;\n",
            "ld.global.b16 %ha, [%a];\n",
            "ld.global.b16 %hb, [%b];\n",
            "{decode}\n",
            "{setp}\n",
            "selp.u16 %res, 1, 0, %c;\n",
            "st.global.u8 [%out], %res;\n",
            "DONE:\n",
            "ret;\n",
            "}}\n",
        ),
        kernel_name = kernel_name,
        target = target,
        decode = decode,
        setp = setp,
    )
}
/// Decode snippet for `cmp_half_ptx` used by the bf16 comparison.
///
/// For each input it packs `%zero16` and the loaded 16-bit value into a
/// 32-bit register and moves those bits unchanged into the `.f32` register
/// (`%fa` / `%fb`).
const BF16_DECODE: &str = concat!(
    "mov.b32 %ua, {%zero16, %ha}; mov.b32 %fa, %ua;\n",
    "mov.b32 %ub, {%zero16, %hb}; mov.b32 %fb, %ub;",
);
/// Decode snippet for `cmp_half_ptx` used by the f16 comparison.
///
/// Converts the loaded 16-bit values `%ha` / `%hb` to `.f32` in `%fa` / `%fb`
/// with `cvt.f32.f16`.
const F16_DECODE: &str = concat!(
    "cvt.f32.f16 %fa, %ha;\n",
    "cvt.f32.f16 %fb, %hb;",
);
/// Renders the PTX source of an element-wise binary logic kernel over byte
/// buffers.
///
/// The generated entry point takes `(a_ptr, b_ptr, out_ptr, n)`. Each thread
/// whose global index is below `n` loads one byte from `a` and `b`, treats any
/// non-zero byte as true, combines the two predicates with `{op}.pred`, and
/// stores `1` or `0` at `out[idx]`.
///
/// * `kernel_name` - name of the `.visible .entry` symbol.
/// * `op` - PTX predicate instruction mnemonic (callers in this file pass
///   `and`, `or` and `xor`).
fn logic_bin_ptx(kernel_name: &str, op: &str) -> String {
    format!(
        concat!(
            // Module header and entry signature.
            ".version 7.0\n",
            ".target sm_52\n",
            ".address_size 64\n",
            ".visible .entry {kernel_name}(\n",
            ".param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n\n",
            ") {{\n",
            // Register declarations.
            ".reg .u32 %idx, %bid, %bdim, %nr;\n",
            ".reg .u64 %a, %b, %out, %off;\n",
            ".reg .u16 %va, %vb, %res;\n",
            ".reg .pred %pa, %pb, %pr, %p;\n",
            // Parameter loads.
            "ld.param.u64 %a, [a_ptr];\n",
            "ld.param.u64 %b, [b_ptr];\n",
            "ld.param.u64 %out, [out_ptr];\n",
            "ld.param.u32 %nr, [n];\n",
            // idx = ctaid.x * ntid.x + tid.x, then bounds check against n.
            "mov.u32 %bid, %ctaid.x;\n",
            "mov.u32 %bdim, %ntid.x;\n",
            "mov.u32 %idx, %tid.x;\n",
            "mad.lo.u32 %idx, %bid, %bdim, %idx;\n",
            "setp.ge.u32 %p, %idx, %nr;\n",
            "@%p bra DONE;\n",
            // All three buffers are indexed by byte: base + idx.
            "cvt.u64.u32 %off, %idx;\n",
            "add.u64 %a, %a, %off;\n",
            "add.u64 %b, %b, %off;\n",
            "add.u64 %out, %out, %off;\n",
            // Load, normalise to predicates, combine, store the 0/1 byte.
            "ld.global.u8 %va, [%a];\n",
            "ld.global.u8 %vb, [%b];\n",
            "setp.ne.u16 %pa, %va, 0;\n",
            "setp.ne.u16 %pb, %vb, 0;\n",
            "{op}.pred %pr, %pa, %pb;\n",
            "selp.u16 %res, 1, 0, %pr;\n",
            "st.global.u8 [%out], %res;\n",
            "DONE:\n",
            "ret;\n",
            "}}\n",
        ),
        kernel_name = kernel_name,
        op = op,
    )
}
/// PTX source of `not_bool_kernel(a_ptr, out_ptr, n)`.
///
/// Each thread whose global index is below `n` loads one byte from `a` and
/// stores `1` at `out[idx]` when that byte is zero, `0` otherwise.
const NOT_BOOL_PTX: &str = concat!(
    // Module header and entry signature.
    ".version 7.0\n",
    ".target sm_52\n",
    ".address_size 64\n",
    ".visible .entry not_bool_kernel(\n",
    ".param .u64 a_ptr, .param .u64 out_ptr, .param .u32 n\n",
    ") {\n",
    // Register declarations.
    ".reg .u32 %idx, %bid, %bdim, %nr;\n",
    ".reg .u64 %a, %out, %off;\n",
    ".reg .u16 %va, %res;\n",
    ".reg .pred %pa, %p;\n",
    // Parameter loads.
    "ld.param.u64 %a, [a_ptr];\n",
    "ld.param.u64 %out, [out_ptr];\n",
    "ld.param.u32 %nr, [n];\n",
    // idx = ctaid.x * ntid.x + tid.x, then bounds check against n.
    "mov.u32 %bid, %ctaid.x;\n",
    "mov.u32 %bdim, %ntid.x;\n",
    "mov.u32 %idx, %tid.x;\n",
    "mad.lo.u32 %idx, %bid, %bdim, %idx;\n",
    "setp.ge.u32 %p, %idx, %nr;\n",
    "@%p bra DONE;\n",
    // Both buffers are indexed by byte: base + idx.
    "cvt.u64.u32 %off, %idx;\n",
    "add.u64 %a, %a, %off;\n",
    "add.u64 %out, %out, %off;\n",
    "ld.global.u8 %va, [%a];\n",
    "// res = (va == 0) ? 1 : 0\n",
    "setp.eq.u16 %pa, %va, 0;\n",
    "selp.u16 %res, 1, 0, %pa;\n",
    "st.global.u8 [%out], %res;\n",
    "DONE:\n",
    "ret;\n",
    "}\n",
);
/// PTX source of `reduce_bool_kernel(a_ptr, out_ptr, n, op)`.
///
/// Only the thread with global index 0 does any work. It seeds an accumulator
/// from `a[0]` (normalised to 0/1), then walks `a[1..n]` sequentially,
/// combining with OR when `op == 0` and with AND otherwise, and finally
/// stores the accumulator byte at `out[0]`. The kernel reads `a[0]`
/// unconditionally, so it must only be launched with `n >= 1`.
const REDUCE_BOOL_PTX: &str = concat!(
    // Module header and entry signature.
    ".version 7.0\n",
    ".target sm_52\n",
    ".address_size 64\n",
    ".visible .entry reduce_bool_kernel(\n",
    ".param .u64 a_ptr, .param .u64 out_ptr, .param .u32 n, .param .u32 op\n",
    ") {\n",
    // Register declarations.
    ".reg .u32 %idx, %bid, %bdim, %nr, %op_r, %i;\n",
    ".reg .u64 %a, %out, %off, %cur;\n",
    ".reg .u16 %acc, %v, %vn;\n",
    ".reg .pred %only0, %p, %is_any, %pacc, %pv;\n",
    // Parameter loads.
    "ld.param.u64 %a, [a_ptr];\n",
    "ld.param.u64 %out, [out_ptr];\n",
    "ld.param.u32 %nr, [n];\n",
    "ld.param.u32 %op_r, [op];\n",
    // Every thread except global index 0 exits immediately.
    "mov.u32 %bid, %ctaid.x;\n",
    "mov.u32 %bdim, %ntid.x;\n",
    "mov.u32 %idx, %tid.x;\n",
    "mad.lo.u32 %idx, %bid, %bdim, %idx;\n",
    "setp.ne.u32 %only0, %idx, 0;\n",
    "@%only0 bra DONE;\n",
    "setp.eq.u32 %is_any, %op_r, 0;\n",
    "// Initialise accumulator from a[0] normalised to 0/1 (n >= 1 guaranteed).\n",
    "ld.global.u8 %v, [%a];\n",
    "setp.ne.u16 %pv, %v, 0;\n",
    "selp.u16 %acc, 1, 0, %pv;\n",
    "mov.u32 %i, 1;\n",
    // Sequential fold over a[1..n].
    "LOOP:\n",
    "setp.ge.u32 %p, %i, %nr;\n",
    "@%p bra STORE;\n",
    "cvt.u64.u32 %off, %i;\n",
    "add.u64 %cur, %a, %off;\n",
    "ld.global.u8 %v, [%cur];\n",
    "setp.ne.u16 %pv, %v, 0;\n",
    "selp.u16 %vn, 1, 0, %pv;\n",
    "setp.ne.u16 %pacc, %acc, 0;\n",
    "setp.ne.u16 %pv, %vn, 0;\n",
    "// any: acc = acc OR v ; all: acc = acc AND v\n",
    "@%is_any or.pred %pacc, %pacc, %pv;\n",
    "@!%is_any and.pred %pacc, %pacc, %pv;\n",
    "selp.u16 %acc, 1, 0, %pacc;\n",
    "add.u32 %i, %i, 1;\n",
    "bra LOOP;\n",
    "STORE:\n",
    "st.global.u8 [%out], %acc;\n",
    "DONE:\n",
    "ret;\n",
    "}\n",
);
/// `op` selector for `reduce_bool_kernel`: the kernel ORs elements together
/// when `op == 0`.
const REDUCE_ANY: u32 = 0;
/// `op` selector for `reduce_bool_kernel`: any non-zero `op` makes the kernel
/// AND elements together.
const REDUCE_ALL: u32 = 1;
/// Compiles (or fetches from the module cache) a two-input kernel and launches
/// it with one thread per element, returning one output byte per element.
///
/// * `a`, `b` - device inputs; they must have the same length.
/// * `device` - supplies the stream, the context and the ordinal passed to the
///   module cache.
/// * `ptx`, `kernel_name` - kernel source and entry name handed to
///   `get_or_compile_owned`.
/// * `err_label` - kernel label reported in `GpuError::PtxCompileFailed`.
///
/// # Errors
///
/// * `GpuError::LengthMismatch` when `a.len() != b.len()`.
/// * `GpuError::PtxCompileFailed` when the kernel cannot be obtained.
/// * Whatever allocation or launch errors convert into through `?`.
///
/// Empty inputs return an empty buffer without compiling or launching.
///
/// NOTE(review): the length is passed to the kernel as `len as u32`, which
/// silently truncates for slices longer than `u32::MAX` elements.
fn launch_cmp<T: DeviceRepr + ValidAsZeroBits>(
    a: &CudaSlice<T>,
    b: &CudaSlice<T>,
    device: &GpuDevice,
    ptx: String,
    kernel_name: String,
    err_label: &'static str,
) -> GpuResult<CudaSlice<u8>> {
    let (len_a, len_b) = (a.len(), b.len());
    if len_a != len_b {
        return Err(GpuError::LengthMismatch { a: len_a, b: len_b });
    }
    let len = len_a;

    let stream = device.stream();
    if len == 0 {
        let empty = stream.alloc_zeros::<u8>(0)?;
        return Ok(empty);
    }

    let func = match get_or_compile_owned(
        device.context(),
        ptx,
        kernel_name,
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        Err(source) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: err_label,
                source,
            })
        }
    };

    let mut result = stream.alloc_zeros::<u8>(len)?;
    let config = launch_1d(len);
    let len_u32 = len as u32;
    // SAFETY: relies on the caller-supplied PTX having the
    // `(a_ptr, b_ptr, out_ptr, n)` signature and reading elements of `T`'s
    // size; `a` and `b` hold `len` elements and `result` holds `len` bytes.
    unsafe {
        stream
            .launch_builder(&func)
            .arg(a)
            .arg(b)
            .arg(&mut result)
            .arg(&len_u32)
            .launch(config)?;
    }
    Ok(result)
}
/// Launches a comparison kernel over 16-bit inputs stored as `u16`.
///
/// Thin wrapper around `launch_cmp` that fixes the compile-error label to
/// `"cmp_half"`; all other arguments are forwarded unchanged.
fn launch_cmp_half(
    a: &CudaSlice<u16>,
    b: &CudaSlice<u16>,
    device: &GpuDevice,
    ptx: String,
    kernel_name: String,
) -> GpuResult<CudaSlice<u8>> {
    const ERR_LABEL: &str = "cmp_half";
    launch_cmp(a, b, device, ptx, kernel_name, ERR_LABEL)
}
/// Launches a binary logic kernel over two byte buffers.
///
/// Thin wrapper around `launch_cmp`: `kernel_name` is used both as the entry
/// name (as an owned `String`) and as the compile-error label.
fn launch_logic_bin(
    a: &CudaSlice<u8>,
    b: &CudaSlice<u8>,
    device: &GpuDevice,
    ptx: String,
    kernel_name: &'static str,
) -> GpuResult<CudaSlice<u8>> {
    let entry_name = String::from(kernel_name);
    launch_cmp(a, b, device, ptx, entry_name, kernel_name)
}
/// Launches `not_bool_kernel` over `a`, returning one output byte per input
/// byte.
///
/// Empty inputs return an empty buffer without compiling or launching.
///
/// # Errors
///
/// `GpuError::PtxCompileFailed` when the kernel cannot be obtained, plus
/// whatever allocation or launch errors convert into through `?`.
///
/// NOTE(review): the length is passed to the kernel as `len as u32`, which
/// silently truncates for slices longer than `u32::MAX` elements.
fn launch_not(a: &CudaSlice<u8>, device: &GpuDevice) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "not_bool_kernel";

    let len = a.len();
    let stream = device.stream();
    if len == 0 {
        let empty = stream.alloc_zeros::<u8>(0)?;
        return Ok(empty);
    }

    let func = match get_or_compile(
        device.context(),
        NOT_BOOL_PTX,
        KERNEL,
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        Err(source) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: KERNEL,
                source,
            })
        }
    };

    let mut result = stream.alloc_zeros::<u8>(len)?;
    let config = launch_1d(len);
    let len_u32 = len as u32;
    // SAFETY: arguments follow the kernel's `(a_ptr, out_ptr, n)` signature;
    // `a` and `result` both hold `len` bytes and the kernel skips threads
    // whose index is not below `n`.
    unsafe {
        stream
            .launch_builder(&func)
            .arg(a)
            .arg(&mut result)
            .arg(&len_u32)
            .launch(config)?;
    }
    Ok(result)
}
/// Reduces a byte buffer to a single 0/1 byte with `reduce_bool_kernel`.
///
/// * `op` - selector forwarded to the kernel (`REDUCE_ANY` or `REDUCE_ALL`).
/// * `empty_identity` - value returned (in a one-element device buffer) when
///   `a` is empty; no kernel is launched in that case.
///
/// The kernel is launched with a single block of a single thread, which walks
/// the input sequentially.
///
/// # Errors
///
/// `GpuError::PtxCompileFailed` when the kernel cannot be obtained, plus
/// whatever allocation, copy or launch errors convert into through `?`.
///
/// NOTE(review): the length is passed to the kernel as `len as u32`, which
/// silently truncates for slices longer than `u32::MAX` elements.
fn launch_reduce_bool(
    a: &CudaSlice<u8>,
    device: &GpuDevice,
    op: u32,
    empty_identity: u8,
) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "reduce_bool_kernel";

    let len = a.len();
    let stream = device.stream();
    if len == 0 {
        let identity = [empty_identity];
        let on_device = stream.clone_htod(&identity)?;
        return Ok(on_device);
    }

    let func = match get_or_compile(
        device.context(),
        REDUCE_BOOL_PTX,
        KERNEL,
        device.ordinal() as u32,
    ) {
        Ok(func) => func,
        Err(source) => {
            return Err(GpuError::PtxCompileFailed {
                kernel: KERNEL,
                source,
            })
        }
    };

    let mut result = stream.alloc_zeros::<u8>(1)?;
    let single_thread = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (1, 1, 1),
        shared_mem_bytes: 0,
    };
    let len_u32 = len as u32;
    // SAFETY: arguments follow the kernel's `(a_ptr, out_ptr, n, op)`
    // signature; `a` holds `len >= 1` bytes (the kernel reads `a[0]`
    // unconditionally) and `result` holds the single byte it writes.
    unsafe {
        stream
            .launch_builder(&func)
            .arg(a)
            .arg(&mut result)
            .arg(&len_u32)
            .arg(&op)
            .launch(single_thread)?;
    }
    Ok(result)
}
/// Builds the `setp` instruction that compares `%va` with `%vb` and writes
/// the predicate `%c`.
///
/// * `op` - PTX comparison operator mnemonic (`eq`, `ne`, `lt`, ...).
/// * `ty` - PTX operand type (`f32`, `f64`, `s32`, `s64`, ...).
///
/// For floating-point types, `ne` is emitted as `neu`. PTX's plain `ne` is an
/// ordered comparison that is false whenever an operand is NaN, while IEEE
/// inequality (`x != NaN`) is true. The unordered form gives that result.
/// Integer types and every other operator are passed through unchanged.
fn setp_for(op: &str, ty: &str) -> String {
    let is_float = ty.starts_with('f');
    let op = if is_float && op == "ne" { "neu" } else { op };
    format!("setp.{op}.{ty} %c, %va, %vb;")
}
/// Element-wise comparison of two `f32` device buffers.
///
/// Returns a device buffer with one byte per element: `1` where the
/// comparison holds, `0` otherwise.
///
/// `op` is a PTX comparison mnemonic. It is embedded verbatim in the kernel
/// name (`cmp_{op}_f32_kernel`) and handed to `setp_for` to build the compare
/// instruction, so an unsupported value surfaces as a kernel compile error.
///
/// # Errors
///
/// Propagates every error of `launch_cmp` (length mismatch, compile failure,
/// allocation or launch failure).
pub fn gpu_cmp_f32(
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let kernel_name = format!("cmp_{op}_f32_kernel");
    let compare = setp_for(op, "f32");
    // Shift of 2: each input element is 4 bytes wide.
    let source = cmp_ptx(&kernel_name, 2, "f32", ".reg .f32 %va, %vb;", &compare);
    launch_cmp(a, b, d, source, kernel_name, "cmp_f32")
}
/// Element-wise comparison of two `f64` device buffers.
///
/// Returns a device buffer with one byte per element: `1` where the
/// comparison holds, `0` otherwise.
///
/// `op` is a PTX comparison mnemonic. It is embedded verbatim in the kernel
/// name (`cmp_{op}_f64_kernel`) and handed to `setp_for` to build the compare
/// instruction, so an unsupported value surfaces as a kernel compile error.
///
/// # Errors
///
/// Propagates every error of `launch_cmp` (length mismatch, compile failure,
/// allocation or launch failure).
pub fn gpu_cmp_f64(
    a: &CudaSlice<f64>,
    b: &CudaSlice<f64>,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let kernel_name = format!("cmp_{op}_f64_kernel");
    let compare = setp_for(op, "f64");
    // Shift of 3: each input element is 8 bytes wide.
    let source = cmp_ptx(&kernel_name, 3, "f64", ".reg .f64 %va, %vb;", &compare);
    launch_cmp(a, b, d, source, kernel_name, "cmp_f64")
}
/// Element-wise signed comparison of two `i32` device buffers.
///
/// Returns a device buffer with one byte per element: `1` where the
/// comparison holds, `0` otherwise.
///
/// `op` is a PTX comparison mnemonic. It is embedded verbatim in the kernel
/// name (`cmp_{op}_i32_kernel`) and handed to `setp_for` with the PTX type
/// `s32`, so an unsupported value surfaces as a kernel compile error.
///
/// # Errors
///
/// Propagates every error of `launch_cmp` (length mismatch, compile failure,
/// allocation or launch failure).
pub fn gpu_cmp_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let kernel_name = format!("cmp_{op}_i32_kernel");
    let compare = setp_for(op, "s32");
    // Shift of 2: each input element is 4 bytes wide.
    let source = cmp_ptx(&kernel_name, 2, "s32", ".reg .s32 %va, %vb;", &compare);
    launch_cmp(a, b, d, source, kernel_name, "cmp_i32")
}
/// Element-wise signed comparison of two `i64` device buffers.
///
/// Returns a device buffer with one byte per element: `1` where the
/// comparison holds, `0` otherwise.
///
/// `op` is a PTX comparison mnemonic. It is embedded verbatim in the kernel
/// name (`cmp_{op}_i64_kernel`) and handed to `setp_for` with the PTX type
/// `s64`, so an unsupported value surfaces as a kernel compile error.
///
/// # Errors
///
/// Propagates every error of `launch_cmp` (length mismatch, compile failure,
/// allocation or launch failure).
pub fn gpu_cmp_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let kernel_name = format!("cmp_{op}_i64_kernel");
    let compare = setp_for(op, "s64");
    // Shift of 3: each input element is 8 bytes wide.
    let source = cmp_ptx(&kernel_name, 3, "s64", ".reg .s64 %va, %vb;", &compare);
    launch_cmp(a, b, d, source, kernel_name, "cmp_i64")
}
/// Element-wise comparison of two bf16 device buffers stored as raw `u16`
/// bit patterns.
///
/// Each value is widened to `f32` on the device (see `BF16_DECODE`) and the
/// comparison runs in `f32`. Returns a device buffer with one byte per
/// element: `1` where the comparison holds, `0` otherwise.
///
/// `op` is a PTX comparison mnemonic. It is embedded verbatim in the kernel
/// name (`cmp_{op}_bf16_kernel`). In the compare instruction `ne` is emitted
/// as `neu`: PTX's plain `ne` is an ordered comparison that is false whenever
/// an operand is NaN, while IEEE inequality is true in that case. All other
/// operators are used as given.
///
/// # Errors
///
/// Propagates every error of `launch_cmp_half` (length mismatch, compile
/// failure, allocation or launch failure).
pub fn gpu_cmp_bf16(
    a: &CudaSlice<u16>,
    b: &CudaSlice<u16>,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let name = format!("cmp_{op}_bf16_kernel");
    // Use the unordered form so that `x != NaN` evaluates to true.
    let cmp = if op == "ne" { "neu" } else { op };
    let setp = format!("setp.{cmp}.f32 %c, %fa, %fb;");
    let ptx = cmp_half_ptx(&name, "sm_52", BF16_DECODE, &setp);
    launch_cmp_half(a, b, d, ptx, name)
}
/// Element-wise comparison of two f16 device buffers stored as raw `u16`
/// bit patterns.
///
/// Each value is converted to `f32` on the device (see `F16_DECODE`) and the
/// comparison runs in `f32`; the kernel is built for target `sm_53`. Returns a
/// device buffer with one byte per element: `1` where the comparison holds,
/// `0` otherwise.
///
/// `op` is a PTX comparison mnemonic. It is embedded verbatim in the kernel
/// name (`cmp_{op}_f16_kernel`). In the compare instruction `ne` is emitted
/// as `neu`: PTX's plain `ne` is an ordered comparison that is false whenever
/// an operand is NaN, while IEEE inequality is true in that case. All other
/// operators are used as given.
///
/// # Errors
///
/// Propagates every error of `launch_cmp_half` (length mismatch, compile
/// failure, allocation or launch failure).
pub fn gpu_cmp_f16(
    a: &CudaSlice<u16>,
    b: &CudaSlice<u16>,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let name = format!("cmp_{op}_f16_kernel");
    // Use the unordered form so that `x != NaN` evaluates to true.
    let cmp = if op == "ne" { "neu" } else { op };
    let setp = format!("setp.{cmp}.f32 %c, %fa, %fb;");
    let ptx = cmp_half_ptx(&name, "sm_53", F16_DECODE, &setp);
    launch_cmp_half(a, b, d, ptx, name)
}
/// Element-wise logical AND of two byte buffers.
///
/// Any non-zero input byte counts as true; the output holds `1` or `0` per
/// element.
///
/// # Errors
///
/// Propagates every error of `launch_logic_bin` (length mismatch, compile
/// failure, allocation or launch failure).
pub fn gpu_and_bool(
    a: &CudaSlice<u8>,
    b: &CudaSlice<u8>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "and_bool_kernel";
    let source = logic_bin_ptx(KERNEL, "and");
    launch_logic_bin(a, b, d, source, KERNEL)
}
/// Element-wise logical OR of two byte buffers.
///
/// Any non-zero input byte counts as true; the output holds `1` or `0` per
/// element.
///
/// # Errors
///
/// Propagates every error of `launch_logic_bin` (length mismatch, compile
/// failure, allocation or launch failure).
pub fn gpu_or_bool(
    a: &CudaSlice<u8>,
    b: &CudaSlice<u8>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "or_bool_kernel";
    let source = logic_bin_ptx(KERNEL, "or");
    launch_logic_bin(a, b, d, source, KERNEL)
}
/// Element-wise logical XOR of two byte buffers.
///
/// Any non-zero input byte counts as true; the output holds `1` or `0` per
/// element.
///
/// # Errors
///
/// Propagates every error of `launch_logic_bin` (length mismatch, compile
/// failure, allocation or launch failure).
pub fn gpu_xor_bool(
    a: &CudaSlice<u8>,
    b: &CudaSlice<u8>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "xor_bool_kernel";
    let source = logic_bin_ptx(KERNEL, "xor");
    launch_logic_bin(a, b, d, source, KERNEL)
}
/// Element-wise logical NOT of a byte buffer.
///
/// Delegates to `launch_not`: the output holds `1` where the input byte is
/// zero and `0` otherwise.
///
/// # Errors
///
/// Propagates every error of `launch_not`.
pub fn gpu_not_bool(a: &CudaSlice<u8>, d: &GpuDevice) -> GpuResult<CudaSlice<u8>> {
launch_not(a, d)
}
/// Reduces a byte buffer to a one-element buffer holding `1` when any input
/// byte is non-zero and `0` otherwise.
///
/// An empty input yields `0` (the identity of OR).
///
/// # Errors
///
/// Propagates every error of `launch_reduce_bool`.
pub fn gpu_any_bool(a: &CudaSlice<u8>, d: &GpuDevice) -> GpuResult<CudaSlice<u8>> {
    // Result reported for an empty input.
    let empty_identity: u8 = 0;
    launch_reduce_bool(a, d, REDUCE_ANY, empty_identity)
}
/// Reduces a byte buffer to a one-element buffer holding `1` when every input
/// byte is non-zero and `0` otherwise.
///
/// An empty input yields `1` (the identity of AND).
///
/// # Errors
///
/// Propagates every error of `launch_reduce_bool`.
pub fn gpu_all_bool(a: &CudaSlice<u8>, d: &GpuDevice) -> GpuResult<CudaSlice<u8>> {
    // Result reported for an empty input.
    let empty_identity: u8 = 1;
    launch_reduce_bool(a, d, REDUCE_ALL, empty_identity)
}