#![cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits};
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::module_cache::get_or_compile;
/// Number of threads per block used by every kernel launched from this file.
const BLOCK_SIZE: u32 = 256;

/// Builds a one-dimensional launch configuration with one thread per element.
///
/// The grid is `ceil(n / BLOCK_SIZE)` blocks of `BLOCK_SIZE` threads, and at
/// least one block is always requested. No shared memory is reserved.
///
/// The block count is computed in `usize` so that large `n` is never wrapped
/// by a narrowing cast, and so that values of `n` close to `u32::MAX` still
/// get their final, partially filled block. A block count that does not fit
/// in `u32` is clamped to `u32::MAX` rather than wrapped.
fn launch_1d(n: usize) -> LaunchConfig {
    let threads_per_block = BLOCK_SIZE as usize;
    // Ceiling division without the `n + (B - 1)` form, which can overflow.
    let blocks = n / threads_per_block + usize::from(n % threads_per_block != 0);
    // The `min` guarantees the value fits, so the cast cannot truncate.
    let grid = (blocks.min(u32::MAX as usize) as u32).max(1);
    LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// Expands to the PTX source text (a `&'static str` built with `concat!`) of
/// one index-select kernel variant.
///
/// All arguments are string literals spliced into the PTX text:
/// - `$kname`: name of the `.entry` function.
/// - `$wsh`: left-shift applied to a value element index to get a byte offset.
/// - `$ish`: left-shift applied to an index element index to get a byte offset.
/// - `$ldi`: load instruction used to read one index element.
/// - `$icvt`: source type suffix of the `cvt.u32.*` that narrows the index.
/// - `$ldv` / `$stv`: load / store instructions used for one value element.
/// - `$vreg` / `$ireg`: register types holding a value / a raw index.
///
/// Each thread handles the output element `gtid = ctaid.x * ntid.x + tid.x`
/// and returns immediately when `gtid >= total`. Otherwise it computes
/// `o = gtid / (out_dim * inner)`, `i = (gtid % (out_dim * inner)) / inner`,
/// `k = gtid % inner`, reads `sel = idx[i]`, and copies
/// `in[(o * in_dim + sel) * inner + k]` to `out[gtid]`. All of that arithmetic
/// is 32-bit unsigned.
///
/// The `outer` parameter is declared in the signature but never loaded.
///
/// NOTE(review): `sel` is not compared against `in_dim` anywhere in the kernel,
/// so an out-of-range or negative index produces an out-of-range source
/// address. Presumably indices are validated by the caller - confirm.
macro_rules! index_select_ptx {
($kname:literal, $wsh:literal, $ish:literal, $ldi:literal, $icvt:literal,
$ldv:literal, $stv:literal, $vreg:literal, $ireg:literal) => {
concat!(
".version 7.0\n.target sm_52\n.address_size 64\n",
".visible .entry ",
// Entry-point name; must match the name passed when the kernel is looked up.
$kname,
"(
.param .u64 in_ptr, .param .u64 idx_ptr, .param .u64 out_ptr,
.param .u32 outer, .param .u32 in_dim, .param .u32 out_dim,
.param .u32 inner, .param .u32 total
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %indim, %outdim, %inn;
.reg .u32 %o, %rem, %i, %k, %slab, %sel, %srcelem;
.reg .u64 %in, %idx, %out, %off, %addr;
.reg ",
// Register type for the raw index value as loaded from memory.
$ireg,
" %selv;
.reg ",
// Register type for the value element being copied.
$vreg,
" %v;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %idx, [idx_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %indim, [in_dim];
ld.param.u32 %outdim, [out_dim];
ld.param.u32 %inn, [inner];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot; @%p bra DONE;
mul.lo.u32 %slab, %outdim, %inn;
div.u32 %o, %gtid, %slab;
rem.u32 %rem, %gtid, %slab;
div.u32 %i, %rem, %inn;
rem.u32 %k, %rem, %inn;
cvt.u64.u32 %off, %i; shl.b64 %off, %off, ",
// Byte offset of idx[i]: the index buffer is addressed by `i` only.
$ish,
"; add.u64 %addr, %idx, %off;
",
// Load the raw index.
$ldi,
" %selv, [%addr];
cvt.u32.",
// Narrow / reinterpret the raw index as u32 for the offset arithmetic below.
$icvt,
" %sel, %selv;
mul.lo.u32 %srcelem, %o, %indim;
add.u32 %srcelem, %srcelem, %sel;
mul.lo.u32 %srcelem, %srcelem, %inn;
add.u32 %srcelem, %srcelem, %k;
cvt.u64.u32 %off, %srcelem; shl.b64 %off, %off, ",
// Byte offset of the source value element.
$wsh,
"; add.u64 %addr, %in, %off;
",
// Load the selected value.
$ldv,
" %v, [%addr];
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, ",
// Byte offset of out[gtid].
$wsh,
"; add.u64 %addr, %out, %off;
",
// Store the value to the output.
$stv,
" [%addr], %v;
DONE:
ret;
}
"
)
};
}
/// Expands to the PTX source text (a `&'static str` built with `concat!`) of
/// one gather kernel variant.
///
/// All arguments are string literals spliced into the PTX text:
/// - `$kname`: name of the `.entry` function.
/// - `$wsh`: left-shift applied to a value element index to get a byte offset.
/// - `$ish`: left-shift applied to an index element index to get a byte offset.
/// - `$ldi`: load instruction used to read one index element.
/// - `$icvt`: source type suffix of the `cvt.u32.*` that narrows the index.
/// - `$ldv` / `$stv`: load / store instructions used for one value element.
/// - `$vreg` / `$ireg`: register types holding a value / a raw index.
///
/// Each thread handles the output element `gtid = ctaid.x * ntid.x + tid.x`
/// and returns immediately when `gtid >= total`. Otherwise it computes
/// `o = gtid / (out_dim * inner)` and `k = gtid % inner`, reads
/// `sel = idx[gtid]` (one index per output element), and copies
/// `in[(o * in_dim + sel) * inner + k]` to `out[gtid]`. All of that arithmetic
/// is 32-bit unsigned.
///
/// The `outer` parameter is declared in the signature but never loaded.
///
/// NOTE(review): `sel` is not compared against `in_dim` anywhere in the kernel,
/// so an out-of-range or negative index produces an out-of-range source
/// address. Presumably indices are validated by the caller - confirm.
macro_rules! gather_ptx {
($kname:literal, $wsh:literal, $ish:literal, $ldi:literal, $icvt:literal,
$ldv:literal, $stv:literal, $vreg:literal, $ireg:literal) => {
concat!(
".version 7.0\n.target sm_52\n.address_size 64\n",
".visible .entry ",
// Entry-point name; must match the name passed when the kernel is looked up.
$kname,
"(
.param .u64 in_ptr, .param .u64 idx_ptr, .param .u64 out_ptr,
.param .u32 outer, .param .u32 in_dim, .param .u32 out_dim,
.param .u32 inner, .param .u32 total
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %indim, %outdim, %inn;
.reg .u32 %o, %rem, %k, %slab, %sel, %srcelem;
.reg .u64 %in, %idx, %out, %off, %addr;
.reg ",
// Register type for the raw index value as loaded from memory.
$ireg,
" %selv;
.reg ",
// Register type for the value element being copied.
$vreg,
" %v;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %idx, [idx_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %indim, [in_dim];
ld.param.u32 %outdim, [out_dim];
ld.param.u32 %inn, [inner];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot; @%p bra DONE;
mul.lo.u32 %slab, %outdim, %inn;
div.u32 %o, %gtid, %slab;
rem.u32 %rem, %gtid, %slab;
rem.u32 %k, %rem, %inn;
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, ",
// Byte offset of idx[gtid]: the index buffer is addressed per output element.
$ish,
"; add.u64 %addr, %idx, %off;
",
// Load the raw index.
$ldi,
" %selv, [%addr];
cvt.u32.",
// Narrow / reinterpret the raw index as u32 for the offset arithmetic below.
$icvt,
" %sel, %selv;
mul.lo.u32 %srcelem, %o, %indim;
add.u32 %srcelem, %srcelem, %sel;
mul.lo.u32 %srcelem, %srcelem, %inn;
add.u32 %srcelem, %srcelem, %k;
cvt.u64.u32 %off, %srcelem; shl.b64 %off, %off, ",
// Byte offset of the source value element.
$wsh,
"; add.u64 %addr, %in, %off;
",
// Load the selected value.
$ldv,
" %v, [%addr];
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, ",
// Byte offset of out[gtid].
$wsh,
"; add.u64 %addr, %out, %off;
",
// Store the value to the output.
$stv,
" [%addr], %v;
DONE:
ret;
}
"
)
};
}
// Index-select kernel sources, one per (value width, index width) pair.
// Argument order follows `index_select_ptx!`: kernel name, value shift,
// index shift, index load, index cvt source type, value load, value store,
// value register type, index register type.

/// PTX for `isel_w2_i32_kernel`: 2-byte values, 32-bit signed indices.
const ISEL_W2_I32_PTX: &str = index_select_ptx!(
"isel_w2_i32_kernel",
"1",
"2",
"ld.global.s32",
"s32",
"ld.global.u16",
"st.global.u16",
".u16",
".s32"
);
/// PTX for `isel_w2_i64_kernel`: 2-byte values, 64-bit signed indices.
const ISEL_W2_I64_PTX: &str = index_select_ptx!(
"isel_w2_i64_kernel",
"1",
"3",
"ld.global.s64",
"s64",
"ld.global.u16",
"st.global.u16",
".u16",
".s64"
);
/// PTX for `isel_w4_i32_kernel`: 4-byte values, 32-bit signed indices.
const ISEL_W4_I32_PTX: &str = index_select_ptx!(
"isel_w4_i32_kernel",
"2",
"2",
"ld.global.s32",
"s32",
"ld.global.u32",
"st.global.u32",
".u32",
".s32"
);
/// PTX for `isel_w4_i64_kernel`: 4-byte values, 64-bit signed indices.
const ISEL_W4_I64_PTX: &str = index_select_ptx!(
"isel_w4_i64_kernel",
"2",
"3",
"ld.global.s64",
"s64",
"ld.global.u32",
"st.global.u32",
".u32",
".s64"
);
/// PTX for `isel_w8_i32_kernel`: 8-byte values, 32-bit signed indices.
const ISEL_W8_I32_PTX: &str = index_select_ptx!(
"isel_w8_i32_kernel",
"3",
"2",
"ld.global.s32",
"s32",
"ld.global.u64",
"st.global.u64",
".u64",
".s32"
);
/// PTX for `isel_w8_i64_kernel`: 8-byte values, 64-bit signed indices.
const ISEL_W8_I64_PTX: &str = index_select_ptx!(
"isel_w8_i64_kernel",
"3",
"3",
"ld.global.s64",
"s64",
"ld.global.u64",
"st.global.u64",
".u64",
".s64"
);
// Gather kernel sources, one per (value width, index width) pair.
// Argument order follows `gather_ptx!`: kernel name, value shift, index
// shift, index load, index cvt source type, value load, value store,
// value register type, index register type.

/// PTX for `gather_w2_i32_kernel`: 2-byte values, 32-bit signed indices.
const GATHER_W2_I32_PTX: &str = gather_ptx!(
"gather_w2_i32_kernel",
"1",
"2",
"ld.global.s32",
"s32",
"ld.global.u16",
"st.global.u16",
".u16",
".s32"
);
/// PTX for `gather_w2_i64_kernel`: 2-byte values, 64-bit signed indices.
const GATHER_W2_I64_PTX: &str = gather_ptx!(
"gather_w2_i64_kernel",
"1",
"3",
"ld.global.s64",
"s64",
"ld.global.u16",
"st.global.u16",
".u16",
".s64"
);
/// PTX for `gather_w4_i32_kernel`: 4-byte values, 32-bit signed indices.
const GATHER_W4_I32_PTX: &str = gather_ptx!(
"gather_w4_i32_kernel",
"2",
"2",
"ld.global.s32",
"s32",
"ld.global.u32",
"st.global.u32",
".u32",
".s32"
);
/// PTX for `gather_w4_i64_kernel`: 4-byte values, 64-bit signed indices.
const GATHER_W4_I64_PTX: &str = gather_ptx!(
"gather_w4_i64_kernel",
"2",
"3",
"ld.global.s64",
"s64",
"ld.global.u32",
"st.global.u32",
".u32",
".s64"
);
/// PTX for `gather_w8_i32_kernel`: 8-byte values, 32-bit signed indices.
const GATHER_W8_I32_PTX: &str = gather_ptx!(
"gather_w8_i32_kernel",
"3",
"2",
"ld.global.s32",
"s32",
"ld.global.u64",
"st.global.u64",
".u64",
".s32"
);
/// PTX for `gather_w8_i64_kernel`: 8-byte values, 64-bit signed indices.
const GATHER_W8_I64_PTX: &str = gather_ptx!(
"gather_w8_i64_kernel",
"3",
"3",
"ld.global.s64",
"s64",
"ld.global.u64",
"st.global.u64",
".u64",
".s64"
);
/// Size in bytes of one value element, used to pick a kernel variant.
///
/// The kernels copy values as raw words of this width, so value types that
/// share a size share a kernel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ValWidth {
    /// 2-byte values.
    W2,
    /// 4-byte values.
    W4,
    /// 8-byte values.
    W8,
}
/// Width of one index element, used to pick a kernel variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum IdxWidth {
    /// 32-bit signed indices.
    I32,
    /// 64-bit signed indices.
    I64,
}
/// Allocates the output buffer and launches one index-select or gather kernel.
///
/// The output holds `outer * out_dim * inner` elements and is zero-initialised
/// before the launch. One thread is launched per output element.
///
/// # Parameters
/// - `input`: source values, addressed by the kernel as
///   `(o * in_dim + sel) * inner + k`.
/// - `idx`: index buffer read by the kernel.
/// - `outer`, `in_dim`, `out_dim`, `inner`: extents forwarded to the kernel.
/// - `device`: supplies the stream, context and ordinal used for the launch.
/// - `ptx`, `kernel_name`: PTX source and the entry point to run from it.
///
/// # Errors
/// - `GpuError::LengthMismatch` when `outer * out_dim * inner` overflows
///   `usize`, or when any extent, the output element count, or the source
///   element count `outer * in_dim * inner` does not fit in `u32`. The kernels
///   do all of their index arithmetic in 32-bit registers, so larger shapes
///   would be silently wrapped.
/// - `GpuError::PtxCompileFailed` when the kernel cannot be obtained.
/// - Allocation and launch failures are propagated with `?`.
///
/// An empty output (`total == 0`) returns an empty buffer without validating
/// the remaining extents or touching the kernel.
///
/// NOTE(review): the lengths of `input` and `idx` and the index values
/// themselves are not checked here - presumably callers guarantee them.
#[allow(clippy::too_many_arguments)]
fn launch_select<V: DeviceRepr + ValidAsZeroBits, I: DeviceRepr + ValidAsZeroBits>(
    input: &CudaSlice<V>,
    idx: &CudaSlice<I>,
    outer: usize,
    in_dim: usize,
    out_dim: usize,
    inner: usize,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<CudaSlice<V>> {
    /// Narrows a host-side extent to the `u32` the kernels take, rejecting
    /// values that would otherwise be truncated.
    fn narrow(extent: usize) -> GpuResult<u32> {
        if extent > u32::MAX as usize {
            Err(GpuError::LengthMismatch {
                a: extent,
                b: u32::MAX as usize,
            })
        } else {
            // Checked above, so the cast is lossless.
            Ok(extent as u32)
        }
    }

    let total = outer
        .checked_mul(out_dim)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: out_dim,
        })?;
    let stream = device.stream();
    if total == 0 {
        return Ok(stream.alloc_zeros::<V>(0)?);
    }

    // Every kernel parameter is a u32; refuse shapes that do not fit instead
    // of launching with wrapped values.
    let outer_u = narrow(outer)?;
    let indim_u = narrow(in_dim)?;
    let outdim_u = narrow(out_dim)?;
    let inner_u = narrow(inner)?;
    let total_u = narrow(total)?;

    // The source element offset is also computed in u32 inside the kernel, so
    // the whole addressed source range has to fit as well.
    let src_extent = outer
        .checked_mul(in_dim)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: in_dim,
        })?;
    narrow(src_extent)?;

    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out = stream.alloc_zeros::<V>(total)?;
    let cfg = launch_1d(total);
    // SAFETY: the arguments are pushed in the order and with the types the
    // PTX entry declares (three pointers followed by five u32 values). `out`
    // holds `total` elements and a thread only stores to `out[gtid]` after
    // checking `gtid < total`. Reads from `input` and `idx` stay in bounds
    // only if the caller supplied buffers and indices consistent with the
    // extents; that is not verified here.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(idx)
            .arg(&mut out)
            .arg(&outer_u)
            .arg(&indim_u)
            .arg(&outdim_u)
            .arg(&inner_u)
            .arg(&total_u)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Looks up the index-select kernel for a value width / index width pair.
///
/// Returns `(ptx_source, entry_point_name)`.
fn isel_ptx(vw: ValWidth, iw: IdxWidth) -> (&'static str, &'static str) {
    // Dispatch on the value width first, then on the index width.
    match vw {
        ValWidth::W2 => match iw {
            IdxWidth::I32 => (ISEL_W2_I32_PTX, "isel_w2_i32_kernel"),
            IdxWidth::I64 => (ISEL_W2_I64_PTX, "isel_w2_i64_kernel"),
        },
        ValWidth::W4 => match iw {
            IdxWidth::I32 => (ISEL_W4_I32_PTX, "isel_w4_i32_kernel"),
            IdxWidth::I64 => (ISEL_W4_I64_PTX, "isel_w4_i64_kernel"),
        },
        ValWidth::W8 => match iw {
            IdxWidth::I32 => (ISEL_W8_I32_PTX, "isel_w8_i32_kernel"),
            IdxWidth::I64 => (ISEL_W8_I64_PTX, "isel_w8_i64_kernel"),
        },
    }
}
/// Looks up the gather kernel for a value width / index width pair.
///
/// Returns `(ptx_source, entry_point_name)`.
fn gathr_ptx(vw: ValWidth, iw: IdxWidth) -> (&'static str, &'static str) {
    // Dispatch on the value width first, then on the index width.
    match vw {
        ValWidth::W2 => match iw {
            IdxWidth::I32 => (GATHER_W2_I32_PTX, "gather_w2_i32_kernel"),
            IdxWidth::I64 => (GATHER_W2_I64_PTX, "gather_w2_i64_kernel"),
        },
        ValWidth::W4 => match iw {
            IdxWidth::I32 => (GATHER_W4_I32_PTX, "gather_w4_i32_kernel"),
            IdxWidth::I64 => (GATHER_W4_I64_PTX, "gather_w4_i64_kernel"),
        },
        ValWidth::W8 => match iw {
            IdxWidth::I32 => (GATHER_W8_I32_PTX, "gather_w8_i32_kernel"),
            IdxWidth::I64 => (GATHER_W8_I64_PTX, "gather_w8_i64_kernel"),
        },
    }
}
/// Generates one public, concretely typed entry point.
///
/// Arguments, in order: the function name, the value element type and its
/// `ValWidth`, the index element type and its `IdxWidth`, and the lookup
/// function (`isel_ptx` or `gathr_ptx`) that maps the two widths to a PTX
/// source and kernel name. The generated function forwards everything to
/// `launch_select`.
macro_rules! select_entry {
    ($fn_name:ident, $val_ty:ty, $val_width:expr, $idx_ty:ty, $idx_width:expr, $lookup:ident) => {
        #[doc = concat!("`", stringify!($lookup), "` on a ", stringify!($val_ty), " value buffer with a ", stringify!($idx_ty), " index buffer.")]
        #[allow(clippy::too_many_arguments)]
        pub fn $fn_name(
            input: &CudaSlice<$val_ty>,
            idx: &CudaSlice<$idx_ty>,
            outer: usize,
            in_dim: usize,
            out_dim: usize,
            inner: usize,
            d: &GpuDevice,
        ) -> GpuResult<CudaSlice<$val_ty>> {
            // Resolve the kernel variant for this type pair, then launch it.
            let (kernel_src, kernel_entry) = $lookup($val_width, $idx_width);
            launch_select(
                input,
                idx,
                outer,
                in_dim,
                out_dim,
                inner,
                d,
                kernel_src,
                kernel_entry,
            )
        }
    };
}
// Public index-select entry points, one per (value type, index type) pair.
// Each invocation expands to a `pub fn` that forwards to `launch_select` with
// the kernel chosen by `isel_ptx`. Value types of equal size share a kernel
// variant (f32 and i32 use W4, f64 and i64 use W8, u16 uses W2).

// f32 values.
select_entry!(
isel_f32_i32,
f32,
ValWidth::W4,
i32,
IdxWidth::I32,
isel_ptx
);
select_entry!(
isel_f32_i64,
f32,
ValWidth::W4,
i64,
IdxWidth::I64,
isel_ptx
);
// f64 values.
select_entry!(
isel_f64_i32,
f64,
ValWidth::W8,
i32,
IdxWidth::I32,
isel_ptx
);
select_entry!(
isel_f64_i64,
f64,
ValWidth::W8,
i64,
IdxWidth::I64,
isel_ptx
);
// i32 values.
select_entry!(
isel_i32_i32,
i32,
ValWidth::W4,
i32,
IdxWidth::I32,
isel_ptx
);
select_entry!(
isel_i32_i64,
i32,
ValWidth::W4,
i64,
IdxWidth::I64,
isel_ptx
);
// i64 values.
select_entry!(
isel_i64_i32,
i64,
ValWidth::W8,
i32,
IdxWidth::I32,
isel_ptx
);
select_entry!(
isel_i64_i64,
i64,
ValWidth::W8,
i64,
IdxWidth::I64,
isel_ptx
);
// u16 values.
select_entry!(
isel_u16_i32,
u16,
ValWidth::W2,
i32,
IdxWidth::I32,
isel_ptx
);
select_entry!(
isel_u16_i64,
u16,
ValWidth::W2,
i64,
IdxWidth::I64,
isel_ptx
);
// Public gather entry points, one per (value type, index type) pair.
// Each invocation expands to a `pub fn` that forwards to `launch_select` with
// the kernel chosen by `gathr_ptx`. Value types of equal size share a kernel
// variant (f32 and i32 use W4, f64 and i64 use W8, u16 uses W2).

// f32 values.
select_entry!(
gather_f32_i32,
f32,
ValWidth::W4,
i32,
IdxWidth::I32,
gathr_ptx
);
select_entry!(
gather_f32_i64,
f32,
ValWidth::W4,
i64,
IdxWidth::I64,
gathr_ptx
);
// f64 values.
select_entry!(
gather_f64_i32,
f64,
ValWidth::W8,
i32,
IdxWidth::I32,
gathr_ptx
);
select_entry!(
gather_f64_i64,
f64,
ValWidth::W8,
i64,
IdxWidth::I64,
gathr_ptx
);
// i32 values.
select_entry!(
gather_i32_i32,
i32,
ValWidth::W4,
i32,
IdxWidth::I32,
gathr_ptx
);
select_entry!(
gather_i32_i64,
i32,
ValWidth::W4,
i64,
IdxWidth::I64,
gathr_ptx
);
// i64 values.
select_entry!(
gather_i64_i32,
i64,
ValWidth::W8,
i32,
IdxWidth::I32,
gathr_ptx
);
select_entry!(
gather_i64_i64,
i64,
ValWidth::W8,
i64,
IdxWidth::I64,
gathr_ptx
);
// u16 values.
select_entry!(
gather_u16_i32,
u16,
ValWidth::W2,
i32,
IdxWidth::I32,
gathr_ptx
);
select_entry!(
gather_u16_i64,
u16,
ValWidth::W2,
i64,
IdxWidth::I64,
gathr_ptx
);
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens CUDA device 0; the tests cannot run without it.
    fn dev() -> GpuDevice {
        GpuDevice::new(0).expect("cuda device")
    }

    /// Selects rows 2, 0, 2 from a 4x2 input (outer = 1, inner = 2) using
    /// 64-bit indices.
    #[test]
    fn index_select_dim0_f32_i64() {
        let d = dev();
        let inp = d
            .stream()
            .clone_htod(&vec![0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0])
            .unwrap();
        let idx = d.stream().clone_htod(&vec![2i64, 0, 2]).unwrap();
        let out = isel_f32_i64(&inp, &idx, 1, 4, 3, 2, &d).unwrap();
        assert_eq!(
            d.stream().clone_dtoh(&out).unwrap(),
            vec![4.0f32, 5.0, 0.0, 1.0, 4.0, 5.0]
        );
    }

    /// Selects columns 2 and 0 from each row of a 2x3 input (outer = 2,
    /// inner = 1) using 32-bit indices.
    #[test]
    fn index_select_dim1_f32_i32() {
        let d = dev();
        let inp = d
            .stream()
            .clone_htod(&vec![10.0f32, 11.0, 12.0, 20.0, 21.0, 22.0])
            .unwrap();
        let idx = d.stream().clone_htod(&vec![2i32, 0]).unwrap();
        let out = isel_f32_i32(&inp, &idx, 2, 3, 2, 1, &d).unwrap();
        assert_eq!(
            d.stream().clone_dtoh(&out).unwrap(),
            vec![12.0f32, 10.0, 22.0, 20.0]
        );
    }

    /// Gathers along the second dimension of a 2x3 i32 input with one index
    /// per output element.
    #[test]
    fn gather_dim1_i32_values() {
        let d = dev();
        let inp = d.stream().clone_htod(&vec![5i32, 6, 7, 8, 9, 10]).unwrap();
        let idx = d.stream().clone_htod(&vec![0i64, 2, 2, 1]).unwrap();
        let out = gather_i32_i64(&inp, &idx, 2, 3, 2, 1, &d).unwrap();
        assert_eq!(d.stream().clone_dtoh(&out).unwrap(), vec![5i32, 7, 10, 9]);
    }

    /// Exercises the 8-byte kernel with both value types that map to it:
    /// swaps the two rows of a 2x2 input for f64 and then for i64.
    #[test]
    fn index_select_f64_and_i64_values() {
        let d = dev();
        let idx = d.stream().clone_htod(&vec![1i32, 0]).unwrap();

        let inp = d.stream().clone_htod(&vec![1.5f64, 2.5, 3.5, 4.5]).unwrap();
        let out = isel_f64_i32(&inp, &idx, 1, 2, 2, 2, &d).unwrap();
        assert_eq!(
            d.stream().clone_dtoh(&out).unwrap(),
            vec![3.5f64, 4.5, 1.5, 2.5]
        );

        // Same shape and indices with i64 values, as the test name promises.
        let inp = d.stream().clone_htod(&vec![10i64, 20, 30, 40]).unwrap();
        let out = isel_i64_i32(&inp, &idx, 1, 2, 2, 2, &d).unwrap();
        assert_eq!(
            d.stream().clone_dtoh(&out).unwrap(),
            vec![30i64, 40, 10, 20]
        );
    }
}