#![cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits};
use crate::buffer::CudaBuffer;
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::module_cache::get_or_compile;
use crate::transfer::{alloc_zeros_f32, alloc_zeros_f64};
/// Threads per block for every kernel launched from this module.
const BLOCK_SIZE: u32 = 256;

/// Builds a 1-D launch configuration with one thread per element for `n` elements.
///
/// The grid is `ceil(n / BLOCK_SIZE)` blocks, with a minimum of one so `n == 0` still
/// yields a well-formed configuration. The ceiling is computed in `usize`.
/// The previous `u32` version truncated `n > u32::MAX` and saturated for `n`
/// near `u32::MAX`, which launched one block too few and silently skipped the tail.
/// A block count that does not fit in `u32` is clamped to `u32::MAX`. That is
/// above CUDA's maximum grid x-dimension, so the launch fails instead of
/// under-covering the data.
fn launch_1d(n: usize) -> LaunchConfig {
    let block = BLOCK_SIZE as usize;
    let blocks = n / block + usize::from(n % block != 0);
    let grid = u32::try_from(blocks.max(1)).unwrap_or(u32::MAX);
    LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
// Generates the PTX source for a gather along one dimension.
//
// The output is treated as a contiguous row-major `outer x out_dim x inner`
// tensor. One thread is launched per output element `gtid` (no stride loop):
//   o   = gtid / (out_dim * inner)
//   k   = gtid % inner
//   sel = idx[gtid]                                  (i64)
//   out[gtid] = in[(o * in_dim + sel) * inner + k]
//
// An index outside `[0, in_dim)` is skipped, and the output element is left
// untouched. Negative values are caught by the same unsigned compare, because
// they reinterpret as huge unsigned numbers. Without this guard the kernel
// would read outside `in`.
// All element offsets are computed in 32 bits, so the host must keep every
// flat extent within `u32`.
//
// Macro arguments:
//   $kname - PTX entry-point name
//   $wsh   - log2 of the element size in bytes ("2" for f32, "3" for f64)
//   $ldv   - global load instruction for one element
//   $stv   - global store instruction for one element
//   $vreg  - PTX register type for one element (".f32" / ".f64")
macro_rules! gather_dim_ptx {
    ($kname:literal, $wsh:literal, $ldv:literal, $stv:literal, $vreg:literal) => {
        concat!(
            ".version 7.0\n.target sm_60\n.address_size 64\n",
            ".visible .entry ",
            $kname,
            "(
.param .u64 in_ptr, .param .u64 idx_ptr, .param .u64 out_ptr,
.param .u32 outer, .param .u32 in_dim, .param .u32 out_dim,
.param .u32 inner, .param .u32 total
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %indim, %outdim, %inn;
.reg .u32 %o, %rem, %k, %slab, %sel, %srcelem;
.reg .u64 %in, %idx, %out, %off, %addr;
.reg .s64 %selv;
.reg ",
            $vreg,
            " %v;
.reg .pred %p;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %idx, [idx_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %indim, [in_dim];
ld.param.u32 %outdim, [out_dim];
ld.param.u32 %inn, [inner];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot; @%p bra DONE;
mul.lo.u32 %slab, %outdim, %inn;
div.u32 %o, %gtid, %slab;
rem.u32 %rem, %gtid, %slab;
rem.u32 %k, %rem, %inn;
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, 3; add.u64 %addr, %idx, %off;
ld.global.s64 %selv, [%addr];
cvt.u64.u32 %off, %indim;
setp.ge.u64 %p, %selv, %off; @%p bra DONE;
cvt.u32.s64 %sel, %selv;
mul.lo.u32 %srcelem, %o, %indim;
add.u32 %srcelem, %srcelem, %sel;
mul.lo.u32 %srcelem, %srcelem, %inn;
add.u32 %srcelem, %srcelem, %k;
cvt.u64.u32 %off, %srcelem; shl.b64 %off, %off, ",
            $wsh,
            "; add.u64 %addr, %in, %off;
",
            $ldv,
            " %v, [%addr];
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, ",
            $wsh,
            "; add.u64 %addr, %out, %off;
",
            $stv,
            " [%addr], %v;
DONE:
ret;
}
"
        )
    };
}
// Generates the PTX source for a scatter (plain store) along one dimension.
//
// The index tensor and `src` are both treated as contiguous row-major
// `outer x idx_dim x inner` tensors. One thread is launched per index element `gtid`:
//   o   = gtid / (idx_dim * inner)
//   k   = gtid % inner
//   sel = idx[gtid]                                  (i64)
//   out[(o * out_dim + sel) * inner + k] = src[gtid]
//
// An index outside `[0, out_dim)` is skipped, including negative values, which
// fail the unsigned compare. Without this guard the store would land outside
// `out`. When several threads target the same element with plain stores, which
// value wins is unspecified. Offsets are 32-bit, so the host must keep every
// flat extent within `u32`.
//
// Macro arguments:
//   $kname - PTX entry-point name
//   $wsh   - log2 of the element size in bytes ("2" for f32, "3" for f64)
//   $ldv   - global load instruction for one element
//   $stv   - global store instruction for one element
//   $vreg  - PTX register type for one element (".f32" / ".f64")
macro_rules! scatter_dim_ptx {
    ($kname:literal, $wsh:literal, $ldv:literal, $stv:literal, $vreg:literal) => {
        concat!(
            ".version 7.0\n.target sm_60\n.address_size 64\n",
            ".visible .entry ",
            $kname,
            "(
.param .u64 out_ptr, .param .u64 idx_ptr, .param .u64 src_ptr,
.param .u32 outer, .param .u32 out_dim, .param .u32 idx_dim,
.param .u32 inner, .param .u32 total
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %outdim, %idxdim, %inn;
.reg .u32 %o, %rem, %k, %slab, %sel, %dstelem;
.reg .u64 %out, %idx, %src, %off, %addr;
.reg .s64 %selv;
.reg ",
            $vreg,
            " %v;
.reg .pred %p;
ld.param.u64 %out, [out_ptr];
ld.param.u64 %idx, [idx_ptr];
ld.param.u64 %src, [src_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %outdim, [out_dim];
ld.param.u32 %idxdim, [idx_dim];
ld.param.u32 %inn, [inner];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot; @%p bra DONE;
mul.lo.u32 %slab, %idxdim, %inn;
div.u32 %o, %gtid, %slab;
rem.u32 %rem, %gtid, %slab;
rem.u32 %k, %rem, %inn;
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, 3; add.u64 %addr, %idx, %off;
ld.global.s64 %selv, [%addr];
cvt.u64.u32 %off, %outdim;
setp.ge.u64 %p, %selv, %off; @%p bra DONE;
cvt.u32.s64 %sel, %selv;
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, ",
            $wsh,
            "; add.u64 %addr, %src, %off;
",
            $ldv,
            " %v, [%addr];
mul.lo.u32 %dstelem, %o, %outdim;
add.u32 %dstelem, %dstelem, %sel;
mul.lo.u32 %dstelem, %dstelem, %inn;
add.u32 %dstelem, %dstelem, %k;
cvt.u64.u32 %off, %dstelem; shl.b64 %off, %off, ",
            $wsh,
            "; add.u64 %addr, %out, %off;
",
            $stv,
            " [%addr], %v;
DONE:
ret;
}
"
        )
    };
}
// Generates the PTX source for a scatter of a single scalar along one dimension.
//
// The index tensor is treated as a contiguous row-major `outer x idx_dim x inner`
// tensor. One thread is launched per index element `gtid`:
//   o   = gtid / (idx_dim * inner)
//   k   = gtid % inner
//   sel = idx[gtid]                                  (i64)
//   out[(o * out_dim + sel) * inner + k] = value
//
// An index outside `[0, out_dim)` is skipped, including negative values, which
// fail the unsigned compare. Without this guard the store would land outside
// `out`. Offsets are 32-bit, so the host must keep every flat extent within `u32`.
//
// Macro arguments:
//   $kname    - PTX entry-point name
//   $wsh      - log2 of the element size in bytes ("2" for f32, "3" for f64)
//   $stv      - global store instruction for one element
//   $vreg     - PTX register type for one element (".f32" / ".f64")
//   $valparam - PTX parameter type of the scalar `value` argument
//   $ldval    - instruction that loads `value` from parameter space
macro_rules! scatter_value_dim_ptx {
    ($kname:literal, $wsh:literal, $stv:literal, $vreg:literal, $valparam:literal,
     $ldval:literal) => {
        concat!(
            ".version 7.0\n.target sm_60\n.address_size 64\n",
            ".visible .entry ",
            $kname,
            "(
.param .u64 out_ptr, .param .u64 idx_ptr, .param ",
            $valparam,
            " value,
.param .u32 outer, .param .u32 out_dim, .param .u32 idx_dim,
.param .u32 inner, .param .u32 total
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %outdim, %idxdim, %inn;
.reg .u32 %o, %rem, %k, %slab, %sel, %dstelem;
.reg .u64 %out, %idx, %off, %addr;
.reg .s64 %selv;
.reg ",
            $vreg,
            " %v;
.reg .pred %p;
ld.param.u64 %out, [out_ptr];
ld.param.u64 %idx, [idx_ptr];
",
            $ldval,
            " %v, [value];
ld.param.u32 %tot, [total];
ld.param.u32 %outdim, [out_dim];
ld.param.u32 %idxdim, [idx_dim];
ld.param.u32 %inn, [inner];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot; @%p bra DONE;
mul.lo.u32 %slab, %idxdim, %inn;
div.u32 %o, %gtid, %slab;
rem.u32 %rem, %gtid, %slab;
rem.u32 %k, %rem, %inn;
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, 3; add.u64 %addr, %idx, %off;
ld.global.s64 %selv, [%addr];
cvt.u64.u32 %off, %outdim;
setp.ge.u64 %p, %selv, %off; @%p bra DONE;
cvt.u32.s64 %sel, %selv;
mul.lo.u32 %dstelem, %o, %outdim;
add.u32 %dstelem, %dstelem, %sel;
mul.lo.u32 %dstelem, %dstelem, %inn;
add.u32 %dstelem, %dstelem, %k;
cvt.u64.u32 %off, %dstelem; shl.b64 %off, %off, ",
            $wsh,
            "; add.u64 %addr, %out, %off;
",
            $stv,
            " [%addr], %v;
DONE:
ret;
}
"
        )
    };
}
// Generates the PTX source for a scatter-add along one dimension.
//
// The index tensor and `src` are both treated as contiguous row-major
// `outer x idx_dim x inner` tensors. One thread is launched per index element `gtid`:
//   o   = gtid / (idx_dim * inner)
//   k   = gtid % inner
//   sel = idx[gtid]                                  (i64)
//   out[(o * out_dim + sel) * inner + k] += src[gtid]   (atomic)
//
// Duplicate indices accumulate through the atomic add. The order of those
// floating-point additions is unspecified. An index outside `[0, out_dim)` is
// skipped, including negative values, which fail the unsigned compare. Without
// this guard the atomic would land outside `out`. Offsets are 32-bit, so the
// host must keep every flat extent within `u32`.
//
// Macro arguments:
//   $kname - PTX entry-point name
//   $wsh   - log2 of the element size in bytes ("2" for f32, "3" for f64)
//   $ldv   - global load instruction for one element
//   $atom  - atomic add instruction for one element
//   $vreg  - PTX register type for one element (".f32" / ".f64")
macro_rules! scatter_add_dim_ptx {
    ($kname:literal, $wsh:literal, $ldv:literal, $atom:literal, $vreg:literal) => {
        concat!(
            ".version 7.0\n.target sm_60\n.address_size 64\n",
            ".visible .entry ",
            $kname,
            "(
.param .u64 out_ptr, .param .u64 idx_ptr, .param .u64 src_ptr,
.param .u32 outer, .param .u32 out_dim, .param .u32 idx_dim,
.param .u32 inner, .param .u32 total
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %outdim, %idxdim, %inn;
.reg .u32 %o, %rem, %k, %slab, %sel, %dstelem;
.reg .u64 %out, %idx, %src, %off, %addr;
.reg .s64 %selv;
.reg ",
            $vreg,
            " %v, %dummy;
.reg .pred %p;
ld.param.u64 %out, [out_ptr];
ld.param.u64 %idx, [idx_ptr];
ld.param.u64 %src, [src_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %outdim, [out_dim];
ld.param.u32 %idxdim, [idx_dim];
ld.param.u32 %inn, [inner];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot; @%p bra DONE;
mul.lo.u32 %slab, %idxdim, %inn;
div.u32 %o, %gtid, %slab;
rem.u32 %rem, %gtid, %slab;
rem.u32 %k, %rem, %inn;
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, 3; add.u64 %addr, %idx, %off;
ld.global.s64 %selv, [%addr];
cvt.u64.u32 %off, %outdim;
setp.ge.u64 %p, %selv, %off; @%p bra DONE;
cvt.u32.s64 %sel, %selv;
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, ",
            $wsh,
            "; add.u64 %addr, %src, %off;
",
            $ldv,
            " %v, [%addr];
mul.lo.u32 %dstelem, %o, %outdim;
add.u32 %dstelem, %dstelem, %sel;
mul.lo.u32 %dstelem, %dstelem, %inn;
add.u32 %dstelem, %dstelem, %k;
cvt.u64.u32 %off, %dstelem; shl.b64 %off, %off, ",
            $wsh,
            "; add.u64 %addr, %out, %off;
",
            $atom,
            " %dummy, [%addr], %v;
DONE:
ret;
}
"
        )
    };
}
// Kernel sources, one per (operation, element type) pair. The second macro
// argument is the log2 of the element size in bytes: "2" for f32, "3" for f64.

/// PTX for a gather along one dimension, f32 elements.
const GATHER_DIM_F32_PTX: &str =
    gather_dim_ptx!("gather_dim_f32_kernel", "2", "ld.global.f32", "st.global.f32", ".f32");

/// PTX for a gather along one dimension, f64 elements.
const GATHER_DIM_F64_PTX: &str =
    gather_dim_ptx!("gather_dim_f64_kernel", "3", "ld.global.f64", "st.global.f64", ".f64");

/// PTX for a plain-store scatter along one dimension, f32 elements.
const SCATTER_DIM_F32_PTX: &str =
    scatter_dim_ptx!("scatter_dim_f32_kernel", "2", "ld.global.f32", "st.global.f32", ".f32");

/// PTX for a plain-store scatter along one dimension, f64 elements.
const SCATTER_DIM_F64_PTX: &str =
    scatter_dim_ptx!("scatter_dim_f64_kernel", "3", "ld.global.f64", "st.global.f64", ".f64");

/// PTX for scattering one scalar f32 value along one dimension.
const SCATTER_VALUE_DIM_F32_PTX: &str = scatter_value_dim_ptx!(
    "scatter_value_dim_f32_kernel",
    "2",
    "st.global.f32",
    ".f32",
    ".f32",
    "ld.param.f32"
);

/// PTX for scattering one scalar f64 value along one dimension.
const SCATTER_VALUE_DIM_F64_PTX: &str = scatter_value_dim_ptx!(
    "scatter_value_dim_f64_kernel",
    "3",
    "st.global.f64",
    ".f64",
    ".f64",
    "ld.param.f64"
);

/// PTX for an atomic scatter-add along one dimension, f32 elements.
const SCATTER_ADD_DIM_F32_PTX: &str = scatter_add_dim_ptx!(
    "scatter_add_dim_f32_kernel",
    "2",
    "ld.global.f32",
    "atom.global.add.f32",
    ".f32"
);

/// PTX for an atomic scatter-add along one dimension, f64 elements.
const SCATTER_ADD_DIM_F64_PTX: &str = scatter_add_dim_ptx!(
    "scatter_add_dim_f64_kernel",
    "3",
    "ld.global.f64",
    "atom.global.add.f64",
    ".f64"
);
// Generates the PTX source for a row-wise segment scatter-add.
//
// `src` is treated as a contiguous `e x d` matrix and `idx` as `e` i64 segment
// ids, one per row. One thread is launched per source element `gtid`:
//   row = gtid / d,  col = gtid % d
//   seg = idx[row]
//   out[seg * d + col] += src[gtid]                  (atomic)
//
// Negative segment ids, such as a -1 padding marker, are skipped. Without that
// check they wrap to a huge u32 row and the atomic lands outside `out`. The
// kernel is not given the number of output rows, so ids that are too large
// cannot be detected here. The caller must keep them below `dim_size`. Offsets
// are 32-bit.
//
// Macro arguments:
//   $kname - PTX entry-point name
//   $wsh   - log2 of the element size in bytes ("2" for f32, "3" for f64)
//   $ldv   - global load instruction for one element
//   $atom  - atomic add instruction for one element
//   $vreg  - PTX register type for one element (".f32" / ".f64")
macro_rules! scatter_add_segments_ptx {
    ($kname:literal, $wsh:literal, $ldv:literal, $atom:literal, $vreg:literal) => {
        concat!(
            ".version 7.0\n.target sm_60\n.address_size 64\n",
            ".visible .entry ",
            $kname,
            "(
.param .u64 out_ptr, .param .u64 idx_ptr, .param .u64 src_ptr,
.param .u32 e, .param .u32 d, .param .u32 total
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %dd, %row, %col, %seg, %dstelem;
.reg .u64 %out, %idx, %src, %off, %addr;
.reg .s64 %segv;
.reg ",
            $vreg,
            " %v, %dummy;
.reg .pred %p;
ld.param.u64 %out, [out_ptr];
ld.param.u64 %idx, [idx_ptr];
ld.param.u64 %src, [src_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %dd, [d];
mov.u32 %bid, %ctaid.x; mov.u32 %bdim, %ntid.x; mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot; @%p bra DONE;
div.u32 %row, %gtid, %dd;
rem.u32 %col, %gtid, %dd;
cvt.u64.u32 %off, %row; shl.b64 %off, %off, 3; add.u64 %addr, %idx, %off;
ld.global.s64 %segv, [%addr];
setp.lt.s64 %p, %segv, 0; @%p bra DONE;
cvt.u32.s64 %seg, %segv;
cvt.u64.u32 %off, %gtid; shl.b64 %off, %off, ",
            $wsh,
            "; add.u64 %addr, %src, %off;
",
            $ldv,
            " %v, [%addr];
mul.lo.u32 %dstelem, %seg, %dd;
add.u32 %dstelem, %dstelem, %col;
cvt.u64.u32 %off, %dstelem; shl.b64 %off, %off, ",
            $wsh,
            "; add.u64 %addr, %out, %off;
",
            $atom,
            " %dummy, [%addr], %v;
DONE:
ret;
}
"
        )
    };
}
// Segment scatter-add kernels. The second macro argument is the log2 of the
// element size in bytes: "2" for f32, "3" for f64.

/// PTX for the row-wise segment scatter-add, f32 elements.
const SCATTER_ADD_SEGMENTS_F32_PTX: &str = scatter_add_segments_ptx!(
    "scatter_add_segments_f32_kernel",
    "2",
    "ld.global.f32",
    "atom.global.add.f32",
    ".f32"
);

/// PTX for the row-wise segment scatter-add, f64 elements.
const SCATTER_ADD_SEGMENTS_F64_PTX: &str = scatter_add_segments_ptx!(
    "scatter_add_segments_f64_kernel",
    "3",
    "ld.global.f64",
    "atom.global.add.f64",
    ".f64"
);
/// Compiles (or fetches from the module cache) a gather kernel and launches it.
///
/// The output has `outer * out_dim * inner` elements and one thread is launched
/// per output element. Before launching, this checks two things:
/// - The flat sizes of the output (`outer * out_dim * inner`) and the input
///   (`outer * in_dim * inner`) must not overflow `usize` and must fit in `u32`,
///   because the kernel computes every offset with 32-bit arithmetic.
/// - `idx` and `out` must hold at least one entry per output element, and
///   `input` must hold the whole `outer x in_dim x inner` tensor, because the
///   kernel accesses them without length checks.
///
/// A zero-sized output is a no-op.
///
/// # Errors
/// - `GpuError::LengthMismatch` when any of the checks above fails.
/// - `GpuError::PtxCompileFailed` when the PTX cannot be compiled.
/// - Any driver error raised by the launch itself.
#[allow(clippy::too_many_arguments)]
fn launch_gather<V: DeviceRepr + ValidAsZeroBits>(
    input: &CudaSlice<V>,
    idx: &CudaSlice<i64>,
    out: &mut CudaSlice<V>,
    outer: usize,
    in_dim: usize,
    out_dim: usize,
    inner: usize,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<()> {
    let total = outer
        .checked_mul(out_dim)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: out_dim,
        })?;
    if total == 0 {
        return Ok(());
    }
    let in_len = outer
        .checked_mul(in_dim)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: in_dim,
        })?;

    // The kernel uses 32-bit element offsets.
    let limit = u32::MAX as usize;
    for extent in [total, in_len] {
        if extent > limit {
            return Err(GpuError::LengthMismatch {
                a: extent,
                b: limit,
            });
        }
    }
    // The kernel does not know the buffer lengths, so every buffer must be large enough.
    for (have, need) in [(idx.len(), total), (out.len(), total), (input.len(), in_len)] {
        if have < need {
            return Err(GpuError::LengthMismatch { a: have, b: need });
        }
    }

    let stream = device.stream();
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let cfg = launch_1d(total);
    // These casts are lossless. Since `total > 0`, every factor is non-zero, so
    // each dimension is <= `total` or `in_len`, and both were checked to fit in u32.
    let (outer_u, indim_u, outdim_u, inner_u, total_u) = (
        outer as u32,
        in_dim as u32,
        out_dim as u32,
        inner as u32,
        total as u32,
    );
    // SAFETY: the argument order and types match the `.entry` signature emitted by
    // `gather_dim_ptx!`: three u64 pointers followed by five u32 scalars. The
    // length checks above ensure that, for indices in `[0, in_dim)`, every
    // device access stays inside its buffer.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(idx)
            .arg(out)
            .arg(&outer_u)
            .arg(&indim_u)
            .arg(&outdim_u)
            .arg(&inner_u)
            .arg(&total_u)
            .launch(cfg)?;
    }
    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn gpu_gather_dim_f32(
input: &CudaBuffer<f32>,
idx: &CudaSlice<i64>,
outer: usize,
in_dim: usize,
out_dim: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let mut out = alloc_zeros_f32(outer * out_dim * inner, device)?;
launch_gather(
input.inner(),
idx,
out.inner_mut(),
outer,
in_dim,
out_dim,
inner,
device,
GATHER_DIM_F32_PTX,
"gather_dim_f32_kernel",
)?;
Ok(out)
}
#[allow(clippy::too_many_arguments)]
pub fn gpu_gather_dim_f64(
input: &CudaBuffer<f64>,
idx: &CudaSlice<i64>,
outer: usize,
in_dim: usize,
out_dim: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let mut out = alloc_zeros_f64(outer * out_dim * inner, device)?;
launch_gather(
input.inner(),
idx,
out.inner_mut(),
outer,
in_dim,
out_dim,
inner,
device,
GATHER_DIM_F64_PTX,
"gather_dim_f64_kernel",
)?;
Ok(out)
}
/// Compiles (or fetches from the module cache) a scatter or scatter-add kernel and launches it.
///
/// One thread is launched per element of the `outer x idx_dim x inner` index
/// tensor. `src` is indexed with the same flat position as `idx`, and results
/// land in the `outer x out_dim x inner` tensor `out`. Before launching, this checks:
/// - the index-tensor size and the output size do not overflow `usize` and fit
///   in `u32`, because the kernel uses 32-bit element offsets;
/// - `idx` and `src` hold at least one entry per index element, and `out` holds
///   the whole output tensor, because the kernel accesses them without length checks.
///
/// An empty index tensor is a no-op.
///
/// # Errors
/// - `GpuError::LengthMismatch` when any of the checks above fails.
/// - `GpuError::PtxCompileFailed` when the PTX cannot be compiled.
/// - Any driver error raised by the launch itself.
#[allow(clippy::too_many_arguments)]
fn launch_scatter<V: DeviceRepr + ValidAsZeroBits>(
    out: &mut CudaSlice<V>,
    idx: &CudaSlice<i64>,
    src: &CudaSlice<V>,
    outer: usize,
    out_dim: usize,
    idx_dim: usize,
    inner: usize,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<()> {
    let total = outer
        .checked_mul(idx_dim)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: idx_dim,
        })?;
    if total == 0 {
        return Ok(());
    }
    let out_len = outer
        .checked_mul(out_dim)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: out_dim,
        })?;

    // The kernel uses 32-bit element offsets.
    let limit = u32::MAX as usize;
    for extent in [total, out_len] {
        if extent > limit {
            return Err(GpuError::LengthMismatch {
                a: extent,
                b: limit,
            });
        }
    }
    // The kernel does not know the buffer lengths, so every buffer must be large enough.
    for (have, need) in [(idx.len(), total), (src.len(), total), (out.len(), out_len)] {
        if have < need {
            return Err(GpuError::LengthMismatch { a: have, b: need });
        }
    }

    let stream = device.stream();
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let cfg = launch_1d(total);
    // These casts are lossless. Since `total > 0`, every factor is non-zero, so
    // each dimension is <= `total` or `out_len`, and both were checked to fit in u32.
    let (outer_u, outdim_u, idxdim_u, inner_u, total_u) = (
        outer as u32,
        out_dim as u32,
        idx_dim as u32,
        inner as u32,
        total as u32,
    );
    // SAFETY: the argument order and types match the `.entry` signature shared by
    // `scatter_dim_ptx!` and `scatter_add_dim_ptx!`: three u64 pointers followed
    // by five u32 scalars. The length checks above ensure that, for indices in
    // `[0, out_dim)`, every device access stays inside its buffer.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(out)
            .arg(idx)
            .arg(src)
            .arg(&outer_u)
            .arg(&outdim_u)
            .arg(&idxdim_u)
            .arg(&inner_u)
            .arg(&total_u)
            .launch(cfg)?;
    }
    Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn gpu_scatter_dim_f32(
input: &CudaBuffer<f32>,
idx: &CudaSlice<i64>,
src: &CudaBuffer<f32>,
outer: usize,
out_dim: usize,
idx_dim: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let mut out = clone_f32(input, outer * out_dim * inner, device)?;
launch_scatter(
out.inner_mut(),
idx,
src.inner(),
outer,
out_dim,
idx_dim,
inner,
device,
SCATTER_DIM_F32_PTX,
"scatter_dim_f32_kernel",
)?;
Ok(out)
}
#[allow(clippy::too_many_arguments)]
pub fn gpu_scatter_dim_f64(
input: &CudaBuffer<f64>,
idx: &CudaSlice<i64>,
src: &CudaBuffer<f64>,
outer: usize,
out_dim: usize,
idx_dim: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let mut out = clone_f64(input, outer * out_dim * inner, device)?;
launch_scatter(
out.inner_mut(),
idx,
src.inner(),
outer,
out_dim,
idx_dim,
inner,
device,
SCATTER_DIM_F64_PTX,
"scatter_dim_f64_kernel",
)?;
Ok(out)
}
#[allow(clippy::too_many_arguments)]
pub fn gpu_scatter_add_dim_f32(
input: &CudaBuffer<f32>,
idx: &CudaSlice<i64>,
src: &CudaBuffer<f32>,
outer: usize,
out_dim: usize,
idx_dim: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let mut out = clone_f32(input, outer * out_dim * inner, device)?;
launch_scatter(
out.inner_mut(),
idx,
src.inner(),
outer,
out_dim,
idx_dim,
inner,
device,
SCATTER_ADD_DIM_F32_PTX,
"scatter_add_dim_f32_kernel",
)?;
Ok(out)
}
#[allow(clippy::too_many_arguments)]
pub fn gpu_scatter_add_dim_f64(
input: &CudaBuffer<f64>,
idx: &CudaSlice<i64>,
src: &CudaBuffer<f64>,
outer: usize,
out_dim: usize,
idx_dim: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let mut out = clone_f64(input, outer * out_dim * inner, device)?;
launch_scatter(
out.inner_mut(),
idx,
src.inner(),
outer,
out_dim,
idx_dim,
inner,
device,
SCATTER_ADD_DIM_F64_PTX,
"scatter_add_dim_f64_kernel",
)?;
Ok(out)
}
/// Compiles (or fetches from the module cache) a segment scatter-add kernel and launches it.
///
/// One thread is launched per element of the `e x d` source matrix. Thread
/// `(row, col)` atomically adds `src[row * d + col]` into `out[idx[row] * d + col]`.
/// Before launching, this checks that:
/// - `e * d` does not overflow, and both it and `out.len()` fit in `u32`,
///   because the kernel uses 32-bit element offsets;
/// - `idx` holds at least `e` entries and `src` at least `e * d`.
///
/// The kernel is not told how many rows `out` has, so segment ids cannot be
/// upper-bound-checked here. The caller must keep them below `out.len() / d`.
/// An empty source is a no-op.
///
/// # Errors
/// - `GpuError::LengthMismatch` when any of the checks above fails.
/// - `GpuError::PtxCompileFailed` when the PTX cannot be compiled.
/// - Any driver error raised by the launch itself.
#[allow(clippy::too_many_arguments)]
fn launch_scatter_add_segments<V: DeviceRepr + ValidAsZeroBits>(
    out: &mut CudaSlice<V>,
    idx: &CudaSlice<i64>,
    src: &CudaSlice<V>,
    e: usize,
    d: usize,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
) -> GpuResult<()> {
    let total = e
        .checked_mul(d)
        .ok_or(GpuError::LengthMismatch { a: e, b: d })?;
    if total == 0 {
        return Ok(());
    }

    // The kernel uses 32-bit element offsets for both `src` and `out`.
    let limit = u32::MAX as usize;
    for extent in [total, out.len()] {
        if extent > limit {
            return Err(GpuError::LengthMismatch {
                a: extent,
                b: limit,
            });
        }
    }
    // The kernel does not know the input lengths, so they must be large enough.
    for (have, need) in [(idx.len(), e), (src.len(), total)] {
        if have < need {
            return Err(GpuError::LengthMismatch { a: have, b: need });
        }
    }

    let stream = device.stream();
    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|err| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: err,
        }
    })?;
    let cfg = launch_1d(total);
    // These casts are lossless: `e` and `d` are both <= `total`, which fits in u32.
    let (e_u, d_u, total_u) = (e as u32, d as u32, total as u32);
    // SAFETY: the argument order and types match the `.entry` signature emitted by
    // `scatter_add_segments_ptx!`: three u64 pointers followed by three u32
    // scalars. `idx` and `src` were checked to be long enough, and writes to
    // `out` stay in bounds as long as every segment id is below `out.len() / d`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(out)
            .arg(idx)
            .arg(src)
            .arg(&e_u)
            .arg(&d_u)
            .arg(&total_u)
            .launch(cfg)?;
    }
    Ok(())
}
pub fn gpu_scatter_add_segments_f32(
src: &CudaBuffer<f32>,
idx: &CudaSlice<i64>,
e: usize,
d: usize,
dim_size: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let mut out = alloc_zeros_f32(dim_size * d, device)?;
launch_scatter_add_segments(
out.inner_mut(),
idx,
src.inner(),
e,
d,
device,
SCATTER_ADD_SEGMENTS_F32_PTX,
"scatter_add_segments_f32_kernel",
)?;
Ok(out)
}
pub fn gpu_scatter_add_segments_f64(
src: &CudaBuffer<f64>,
idx: &CudaSlice<i64>,
e: usize,
d: usize,
dim_size: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let mut out = alloc_zeros_f64(dim_size * d, device)?;
launch_scatter_add_segments(
out.inner_mut(),
idx,
src.inner(),
e,
d,
device,
SCATTER_ADD_SEGMENTS_F64_PTX,
"scatter_add_segments_f64_kernel",
)?;
Ok(out)
}
#[allow(clippy::too_many_arguments)]
pub fn gpu_scatter_value_dim_f32(
input: &CudaBuffer<f32>,
idx: &CudaSlice<i64>,
value: f32,
outer: usize,
out_dim: usize,
idx_dim: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let mut out = clone_f32(input, outer * out_dim * inner, device)?;
launch_scatter_value_f32(
out.inner_mut(),
idx,
value,
outer,
out_dim,
idx_dim,
inner,
device,
)?;
Ok(out)
}
#[allow(clippy::too_many_arguments)]
pub fn gpu_scatter_value_dim_f64(
input: &CudaBuffer<f64>,
idx: &CudaSlice<i64>,
value: f64,
outer: usize,
out_dim: usize,
idx_dim: usize,
inner: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let mut out = clone_f64(input, outer * out_dim * inner, device)?;
launch_scatter_value_f64(
out.inner_mut(),
idx,
value,
outer,
out_dim,
idx_dim,
inner,
device,
)?;
Ok(out)
}
/// Compiles (or fetches from the module cache) the f32 scatter-value kernel and launches it.
///
/// One thread is launched per element of the `outer x idx_dim x inner` index
/// tensor, and each writes `value` into the `outer x out_dim x inner` tensor
/// `out`. Before launching, this checks that:
/// - the index-tensor and output sizes do not overflow `usize` and fit in `u32`,
///   because the kernel uses 32-bit element offsets;
/// - `idx` and `out` are large enough for the kernel's unchecked accesses.
///
/// An empty index tensor is a no-op.
///
/// # Errors
/// - `GpuError::LengthMismatch` when any of the checks above fails.
/// - `GpuError::PtxCompileFailed` when the PTX cannot be compiled.
/// - Any driver error raised by the launch itself.
#[allow(clippy::too_many_arguments)]
fn launch_scatter_value_f32(
    out: &mut CudaSlice<f32>,
    idx: &CudaSlice<i64>,
    value: f32,
    outer: usize,
    out_dim: usize,
    idx_dim: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    let total = outer
        .checked_mul(idx_dim)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: idx_dim,
        })?;
    if total == 0 {
        return Ok(());
    }
    let out_len = outer
        .checked_mul(out_dim)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: out_dim,
        })?;

    // The kernel uses 32-bit element offsets.
    let limit = u32::MAX as usize;
    for extent in [total, out_len] {
        if extent > limit {
            return Err(GpuError::LengthMismatch {
                a: extent,
                b: limit,
            });
        }
    }
    // The kernel does not know the buffer lengths, so every buffer must be large enough.
    for (have, need) in [(idx.len(), total), (out.len(), out_len)] {
        if have < need {
            return Err(GpuError::LengthMismatch { a: have, b: need });
        }
    }

    let stream = device.stream();
    let ctx = device.context();
    let f = get_or_compile(
        ctx,
        SCATTER_VALUE_DIM_F32_PTX,
        "scatter_value_dim_f32_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "scatter_value_dim_f32_kernel",
        source: e,
    })?;
    let cfg = launch_1d(total);
    // These casts are lossless. Since `total > 0`, every factor is non-zero, so
    // each dimension is <= `total` or `out_len`, and both were checked to fit in u32.
    let (outer_u, outdim_u, idxdim_u, inner_u, total_u) = (
        outer as u32,
        out_dim as u32,
        idx_dim as u32,
        inner as u32,
        total as u32,
    );
    // SAFETY: the argument order and types match the `.entry` signature emitted by
    // `scatter_value_dim_ptx!` for f32: two u64 pointers, one f32 value, then
    // five u32 scalars. The length checks above ensure that, for indices in
    // `[0, out_dim)`, every device access stays inside its buffer.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(out)
            .arg(idx)
            .arg(&value)
            .arg(&outer_u)
            .arg(&outdim_u)
            .arg(&idxdim_u)
            .arg(&inner_u)
            .arg(&total_u)
            .launch(cfg)?;
    }
    Ok(())
}
/// Compiles (or fetches from the module cache) the f64 scatter-value kernel and launches it.
///
/// One thread is launched per element of the `outer x idx_dim x inner` index
/// tensor, and each writes `value` into the `outer x out_dim x inner` tensor
/// `out`. Before launching, this checks that:
/// - the index-tensor and output sizes do not overflow `usize` and fit in `u32`,
///   because the kernel uses 32-bit element offsets;
/// - `idx` and `out` are large enough for the kernel's unchecked accesses.
///
/// An empty index tensor is a no-op.
///
/// # Errors
/// - `GpuError::LengthMismatch` when any of the checks above fails.
/// - `GpuError::PtxCompileFailed` when the PTX cannot be compiled.
/// - Any driver error raised by the launch itself.
#[allow(clippy::too_many_arguments)]
fn launch_scatter_value_f64(
    out: &mut CudaSlice<f64>,
    idx: &CudaSlice<i64>,
    value: f64,
    outer: usize,
    out_dim: usize,
    idx_dim: usize,
    inner: usize,
    device: &GpuDevice,
) -> GpuResult<()> {
    let total = outer
        .checked_mul(idx_dim)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: idx_dim,
        })?;
    if total == 0 {
        return Ok(());
    }
    let out_len = outer
        .checked_mul(out_dim)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: out_dim,
        })?;

    // The kernel uses 32-bit element offsets.
    let limit = u32::MAX as usize;
    for extent in [total, out_len] {
        if extent > limit {
            return Err(GpuError::LengthMismatch {
                a: extent,
                b: limit,
            });
        }
    }
    // The kernel does not know the buffer lengths, so every buffer must be large enough.
    for (have, need) in [(idx.len(), total), (out.len(), out_len)] {
        if have < need {
            return Err(GpuError::LengthMismatch { a: have, b: need });
        }
    }

    let stream = device.stream();
    let ctx = device.context();
    let f = get_or_compile(
        ctx,
        SCATTER_VALUE_DIM_F64_PTX,
        "scatter_value_dim_f64_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "scatter_value_dim_f64_kernel",
        source: e,
    })?;
    let cfg = launch_1d(total);
    // These casts are lossless. Since `total > 0`, every factor is non-zero, so
    // each dimension is <= `total` or `out_len`, and both were checked to fit in u32.
    let (outer_u, outdim_u, idxdim_u, inner_u, total_u) = (
        outer as u32,
        out_dim as u32,
        idx_dim as u32,
        inner as u32,
        total as u32,
    );
    // SAFETY: the argument order and types match the `.entry` signature emitted by
    // `scatter_value_dim_ptx!` for f64: two u64 pointers, one f64 value, then
    // five u32 scalars. The length checks above ensure that, for indices in
    // `[0, out_dim)`, every device access stays inside its buffer.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(out)
            .arg(idx)
            .arg(&value)
            .arg(&outer_u)
            .arg(&outdim_u)
            .arg(&idxdim_u)
            .arg(&inner_u)
            .arg(&total_u)
            .launch(cfg)?;
    }
    Ok(())
}
fn clone_f32(
input: &CudaBuffer<f32>,
len: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f32>> {
let mut out = alloc_zeros_f32(len, device)?;
if len > 0 {
let stream = device.stream();
stream.memcpy_dtod(&input.inner().slice(0..len), out.inner_mut())?;
}
Ok(out)
}
fn clone_f64(
input: &CudaBuffer<f64>,
len: usize,
device: &GpuDevice,
) -> GpuResult<CudaBuffer<f64>> {
let mut out = alloc_zeros_f64(len, device)?;
if len > 0 {
let stream = device.stream();
stream.memcpy_dtod(&input.inner().slice(0..len), out.inner_mut())?;
}
Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::transfer::{cpu_to_gpu, gpu_to_cpu};

    /// Opens CUDA device 0. These tests require a GPU.
    fn dev() -> GpuDevice {
        GpuDevice::new(0).expect("cuda device")
    }

    /// Uploads an i64 index array to the device.
    fn htod_i64(d: &GpuDevice, v: &[i64]) -> CudaSlice<i64> {
        d.stream().clone_htod(&v.to_vec()).expect("htod i64")
    }

    #[test]
    fn gather_dim_f32_dim1() {
        let d = dev();
        let inp = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &d).unwrap();
        let idx = htod_i64(&d, &[0i64, 2, 1, 0]);
        let out = gpu_gather_dim_f32(&inp, &idx, 2, 3, 2, 1, &d).unwrap();
        assert_eq!(gpu_to_cpu(&out, &d).unwrap()[..4], [1.0f32, 3.0, 5.0, 4.0]);
    }

    /// Middle dimension of a (2, 2, 2) tensor, so `outer > 1` and `inner > 1`.
    #[test]
    fn gather_dim_f32_middle_dim_with_outer_and_inner() {
        let d = dev();
        let inp = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &d).unwrap();
        // idx shape (2, 1, 2)
        let idx = htod_i64(&d, &[1i64, 0, 0, 1]);
        let out = gpu_gather_dim_f32(&inp, &idx, 2, 2, 1, 2, &d).unwrap();
        assert_eq!(gpu_to_cpu(&out, &d).unwrap()[..4], [3.0f32, 2.0, 5.0, 8.0]);
    }

    #[test]
    fn scatter_dim_f32_dim1() {
        let d = dev();
        let inp = cpu_to_gpu(&[0.0f32; 6], &d).unwrap();
        let src = cpu_to_gpu(&[5.0f32, 6.0], &d).unwrap();
        let idx = htod_i64(&d, &[2i64, 0]);
        let out = gpu_scatter_dim_f32(&inp, &idx, &src, 2, 3, 1, 1, &d).unwrap();
        assert_eq!(
            gpu_to_cpu(&out, &d).unwrap()[..6],
            [0.0f32, 0.0, 5.0, 6.0, 0.0, 0.0]
        );
    }

    /// Dimension 0 of a (3, 2) tensor, so `inner > 1`.
    #[test]
    fn scatter_dim_f64_dim0_with_inner() {
        let d = dev();
        let inp = cpu_to_gpu(&[0.0f64; 6], &d).unwrap();
        let src = cpu_to_gpu(&[7.0f64, 8.0], &d).unwrap();
        let idx = htod_i64(&d, &[2i64, 0]);
        let out = gpu_scatter_dim_f64(&inp, &idx, &src, 1, 3, 1, 2, &d).unwrap();
        assert_eq!(
            gpu_to_cpu(&out, &d).unwrap()[..6],
            [0.0f64, 8.0, 0.0, 0.0, 7.0, 0.0]
        );
    }

    #[test]
    fn scatter_value_dim_f32_1d() {
        let d = dev();
        let inp = cpu_to_gpu(&[0.0f32; 5], &d).unwrap();
        let idx = htod_i64(&d, &[1i64, 3, 0]);
        let out = gpu_scatter_value_dim_f32(&inp, &idx, 9.0, 1, 5, 3, 1, &d).unwrap();
        assert_eq!(
            gpu_to_cpu(&out, &d).unwrap()[..5],
            [9.0f32, 9.0, 0.0, 9.0, 0.0]
        );
    }

    #[test]
    fn scatter_value_dim_f64_1d_keeps_untouched_input() {
        let d = dev();
        let inp = cpu_to_gpu(&[1.0f64, 2.0, 3.0, 4.0], &d).unwrap();
        let idx = htod_i64(&d, &[3i64, 1]);
        let out = gpu_scatter_value_dim_f64(&inp, &idx, -1.5, 1, 4, 2, 1, &d).unwrap();
        assert_eq!(
            gpu_to_cpu(&out, &d).unwrap()[..4],
            [1.0f64, -1.5, 3.0, -1.5]
        );
    }

    #[test]
    fn scatter_add_dim_f32_duplicate_indices() {
        let d = dev();
        let inp = cpu_to_gpu(&[1.0f32, 2.0, 3.0], &d).unwrap();
        let src = cpu_to_gpu(&[10.0f32, 20.0, 30.0], &d).unwrap();
        let idx = htod_i64(&d, &[0i64, 2, 0]);
        let out = gpu_scatter_add_dim_f32(&inp, &idx, &src, 1, 3, 3, 1, &d).unwrap();
        assert_eq!(gpu_to_cpu(&out, &d).unwrap()[..3], [41.0f32, 2.0, 23.0]);
    }

    #[test]
    fn scatter_add_dim_f64_duplicate_indices() {
        let d = dev();
        let inp = cpu_to_gpu(&[1.0f64, 2.0, 3.0], &d).unwrap();
        let src = cpu_to_gpu(&[10.0f64, 20.0, 30.0], &d).unwrap();
        let idx = htod_i64(&d, &[0i64, 2, 0]);
        let out = gpu_scatter_add_dim_f64(&inp, &idx, &src, 1, 3, 3, 1, &d).unwrap();
        assert_eq!(gpu_to_cpu(&out, &d).unwrap()[..3], [41.0f64, 2.0, 23.0]);
    }

    #[test]
    fn gather_dim_f64_dim0() {
        let d = dev();
        let inp = cpu_to_gpu(&[1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0], &d).unwrap();
        let idx = htod_i64(&d, &[2i64, 0, 1, 1]);
        let out = gpu_gather_dim_f64(&inp, &idx, 1, 3, 2, 2, &d).unwrap();
        assert_eq!(gpu_to_cpu(&out, &d).unwrap()[..4], [5.0f64, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn scatter_add_segments_f32_basic() {
        let d = dev();
        let src = cpu_to_gpu(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &d).unwrap();
        let idx = htod_i64(&d, &[0i64, 1, 0]);
        let out = gpu_scatter_add_segments_f32(&src, &idx, 3, 2, 2, &d).unwrap();
        assert_eq!(gpu_to_cpu(&out, &d).unwrap()[..4], [6.0f32, 8.0, 3.0, 4.0]);
    }

    #[test]
    fn scatter_add_segments_f32_duplicate_and_empty_row() {
        let d = dev();
        let src = cpu_to_gpu(&[1.0f32, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0], &d).unwrap();
        let idx = htod_i64(&d, &[0i64, 0, 0, 0]);
        let out = gpu_scatter_add_segments_f32(&src, &idx, 4, 2, 2, &d).unwrap();
        let got = gpu_to_cpu(&out, &d).unwrap();
        assert_eq!(got[..4], [10.0f32, 100.0, 0.0, 0.0]);
    }

    #[test]
    fn scatter_add_segments_f64_basic() {
        let d = dev();
        let src = cpu_to_gpu(&[1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0], &d).unwrap();
        let idx = htod_i64(&d, &[0i64, 1, 0]);
        let out = gpu_scatter_add_segments_f64(&src, &idx, 3, 2, 2, &d).unwrap();
        assert_eq!(gpu_to_cpu(&out, &d).unwrap()[..4], [6.0f64, 8.0, 3.0, 4.0]);
    }
}