#![cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits};
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::module_cache::get_or_compile;
/// Threads per block used by `launch_1d` for every kernel launch in this file.
const BLOCK_SIZE: u32 = 256;
/// Builds a one-dimensional launch configuration with at least `n` threads.
///
/// The grid holds `ceil(n / BLOCK_SIZE)` blocks of `BLOCK_SIZE` threads, with a
/// minimum of one block so that `n == 0` still yields a valid configuration.
/// The block count is clamped to `u32::MAX`, the widest value `grid_dim` can
/// carry.
fn launch_1d(n: usize) -> LaunchConfig {
    // Do the ceiling division in 64 bits. Casting `n` to `u32` first wraps for
    // large `n`, and a saturating add rounds the block count *down* when `n`
    // is within `BLOCK_SIZE - 1` of `u32::MAX`, leaving tail threads unlaunched.
    let block = u64::from(BLOCK_SIZE);
    let n = n as u64;
    let blocks = n / block + u64::from(n % block != 0);
    // Lossless after the clamp.
    let grid = blocks.clamp(1, u64::from(u32::MAX)) as u32;
    LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// Kernel `op` selector for argmax. The kernels test `op == 0` to pick the
/// maximum.
const ARG_MAX: u32 = 0;
/// Kernel `op` selector for argmin. The kernels treat every non-zero `op` as
/// a minimum reduction.
const ARG_MIN: u32 = 1;
/// PTX source for `argreduce_f32_kernel`, the arg-reduction kernel for `f32`
/// input.
///
/// Kernel parameters, in order: `in_ptr`, `out_ptr`, `outer`, `dim_size`,
/// `inner`, `total`, `op`. `outer` is declared but never loaded by the body.
///
/// Every thread whose global id `gtid` is below `total` produces one output.
/// It starts at input element `(gtid / inner) * dim_size * inner + gtid % inner`
/// and visits `dim_size` elements spaced `inner` apart. `op == 0` keeps the
/// maximum, any other value keeps the minimum. Comparisons are strict, so the
/// earliest position wins on ties. The winning position is written as a signed
/// 64-bit value at `out_ptr + gtid * 8`.
///
/// All element indices are computed in 32-bit registers.
///
/// NOTE(review): per the PTX ISA, `setp.gt`/`setp.lt` are ordered comparisons,
/// so a NaN candidate should never replace the running value — confirm this is
/// the intended NaN policy.
const ARGREDUCE_F32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry argreduce_f32_kernel(
.param .u64 in_ptr, .param .u64 out_ptr,
.param .u32 outer, .param .u32 dim_size, .param .u32 inner,
.param .u32 total, .param .u32 op
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %dim, %inn, %op_r;
.reg .u32 %oidx, %iidx, %base, %j, %elem, %best_j;
.reg .u64 %in, %out, %off, %addr;
.reg .f32 %v, %acc;
.reg .s64 %best_s64;
.reg .pred %p, %is_max, %not_max, %better, %lt, %gt;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %dim, [dim_size];
ld.param.u32 %inn, [inner];
ld.param.u32 %op_r, [op];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot;
@%p bra DONE;
setp.eq.u32 %is_max, %op_r, 0;
not.pred %not_max, %is_max;
div.u32 %oidx, %gtid, %inn;
rem.u32 %iidx, %gtid, %inn;
// base = (oidx * dim) * inn + iidx
mul.lo.u32 %base, %oidx, %dim;
mul.lo.u32 %base, %base, %inn;
add.u32 %base, %base, %iidx;
// seed with element j=0
cvt.u64.u32 %off, %base;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %acc, [%addr];
mov.u32 %best_j, 0;
mov.u32 %j, 1;
LOOP:
setp.ge.u32 %p, %j, %dim;
@%p bra STORE;
// elem = base + j*inn
mul.lo.u32 %elem, %j, %inn;
add.u32 %elem, %elem, %base;
cvt.u64.u32 %off, %elem;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.f32 %v, [%addr];
// strict compare: argmax keeps first-greatest, argmin first-least
setp.gt.f32 %gt, %v, %acc;
setp.lt.f32 %lt, %v, %acc;
and.pred %gt, %gt, %is_max;
and.pred %lt, %lt, %not_max;
or.pred %better, %gt, %lt;
@%better mov.f32 %acc, %v;
@%better mov.u32 %best_j, %j;
add.u32 %j, %j, 1;
bra LOOP;
STORE:
cvt.s64.u32 %best_s64, %best_j;
cvt.u64.u32 %off, %gtid;
shl.b64 %off, %off, 3;
add.u64 %addr, %out, %off;
st.global.s64 [%addr], %best_s64;
DONE:
ret;
}
";
/// PTX source for `argreduce_f64_kernel`, the arg-reduction kernel for `f64`
/// input.
///
/// Kernel parameters, in order: `in_ptr`, `out_ptr`, `outer`, `dim_size`,
/// `inner`, `total`, `op`. `outer` is declared but never loaded by the body.
///
/// Every thread whose global id `gtid` is below `total` starts at input
/// element `(gtid / inner) * dim_size * inner + gtid % inner` and visits
/// `dim_size` elements spaced `inner` apart; input offsets are scaled by 8
/// bytes. `op == 0` keeps the maximum, any other value the minimum. The
/// comparisons are strict, so the earliest position wins on ties. The winning
/// position is written as a signed 64-bit value at `out_ptr + gtid * 8`.
///
/// All element indices are computed in 32-bit registers.
const ARGREDUCE_F64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry argreduce_f64_kernel(
.param .u64 in_ptr, .param .u64 out_ptr,
.param .u32 outer, .param .u32 dim_size, .param .u32 inner,
.param .u32 total, .param .u32 op
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %dim, %inn, %op_r;
.reg .u32 %oidx, %iidx, %base, %j, %elem, %best_j;
.reg .u64 %in, %out, %off, %addr;
.reg .f64 %v, %acc;
.reg .s64 %best_s64;
.reg .pred %p, %is_max, %not_max, %better, %lt, %gt;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %dim, [dim_size];
ld.param.u32 %inn, [inner];
ld.param.u32 %op_r, [op];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot;
@%p bra DONE;
setp.eq.u32 %is_max, %op_r, 0;
not.pred %not_max, %is_max;
div.u32 %oidx, %gtid, %inn;
rem.u32 %iidx, %gtid, %inn;
mul.lo.u32 %base, %oidx, %dim;
mul.lo.u32 %base, %base, %inn;
add.u32 %base, %base, %iidx;
cvt.u64.u32 %off, %base;
shl.b64 %off, %off, 3;
add.u64 %addr, %in, %off;
ld.global.f64 %acc, [%addr];
mov.u32 %best_j, 0;
mov.u32 %j, 1;
LOOP:
setp.ge.u32 %p, %j, %dim;
@%p bra STORE;
mul.lo.u32 %elem, %j, %inn;
add.u32 %elem, %elem, %base;
cvt.u64.u32 %off, %elem;
shl.b64 %off, %off, 3;
add.u64 %addr, %in, %off;
ld.global.f64 %v, [%addr];
setp.gt.f64 %gt, %v, %acc;
setp.lt.f64 %lt, %v, %acc;
and.pred %gt, %gt, %is_max;
and.pred %lt, %lt, %not_max;
or.pred %better, %gt, %lt;
@%better mov.f64 %acc, %v;
@%better mov.u32 %best_j, %j;
add.u32 %j, %j, 1;
bra LOOP;
STORE:
cvt.s64.u32 %best_s64, %best_j;
cvt.u64.u32 %off, %gtid;
shl.b64 %off, %off, 3;
add.u64 %addr, %out, %off;
st.global.s64 [%addr], %best_s64;
DONE:
ret;
}
";
/// PTX source for `argreduce_i32_kernel`, the arg-reduction kernel for signed
/// 32-bit integer input.
///
/// Kernel parameters, in order: `in_ptr`, `out_ptr`, `outer`, `dim_size`,
/// `inner`, `total`, `op`. `outer` is declared but never loaded by the body.
///
/// Every thread whose global id `gtid` is below `total` starts at input
/// element `(gtid / inner) * dim_size * inner + gtid % inner` and visits
/// `dim_size` elements spaced `inner` apart; input offsets are scaled by 4
/// bytes. Values are compared as signed integers (`setp.gt.s32` /
/// `setp.lt.s32`). `op == 0` keeps the maximum, any other value the minimum.
/// The comparisons are strict, so the earliest position wins on ties. The
/// winning position is written as a signed 64-bit value at
/// `out_ptr + gtid * 8`.
///
/// All element indices are computed in 32-bit registers.
const ARGREDUCE_I32_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry argreduce_i32_kernel(
.param .u64 in_ptr, .param .u64 out_ptr,
.param .u32 outer, .param .u32 dim_size, .param .u32 inner,
.param .u32 total, .param .u32 op
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %dim, %inn, %op_r;
.reg .u32 %oidx, %iidx, %base, %j, %elem, %best_j;
.reg .u64 %in, %out, %off, %addr;
.reg .s32 %v, %acc;
.reg .s64 %best_s64;
.reg .pred %p, %is_max, %not_max, %better, %lt, %gt;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %dim, [dim_size];
ld.param.u32 %inn, [inner];
ld.param.u32 %op_r, [op];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot;
@%p bra DONE;
setp.eq.u32 %is_max, %op_r, 0;
not.pred %not_max, %is_max;
div.u32 %oidx, %gtid, %inn;
rem.u32 %iidx, %gtid, %inn;
mul.lo.u32 %base, %oidx, %dim;
mul.lo.u32 %base, %base, %inn;
add.u32 %base, %base, %iidx;
cvt.u64.u32 %off, %base;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.s32 %acc, [%addr];
mov.u32 %best_j, 0;
mov.u32 %j, 1;
LOOP:
setp.ge.u32 %p, %j, %dim;
@%p bra STORE;
mul.lo.u32 %elem, %j, %inn;
add.u32 %elem, %elem, %base;
cvt.u64.u32 %off, %elem;
shl.b64 %off, %off, 2;
add.u64 %addr, %in, %off;
ld.global.s32 %v, [%addr];
setp.gt.s32 %gt, %v, %acc;
setp.lt.s32 %lt, %v, %acc;
and.pred %gt, %gt, %is_max;
and.pred %lt, %lt, %not_max;
or.pred %better, %gt, %lt;
@%better mov.s32 %acc, %v;
@%better mov.u32 %best_j, %j;
add.u32 %j, %j, 1;
bra LOOP;
STORE:
cvt.s64.u32 %best_s64, %best_j;
cvt.u64.u32 %off, %gtid;
shl.b64 %off, %off, 3;
add.u64 %addr, %out, %off;
st.global.s64 [%addr], %best_s64;
DONE:
ret;
}
";
/// PTX source for `argreduce_i64_kernel`, the arg-reduction kernel for signed
/// 64-bit integer input.
///
/// Kernel parameters, in order: `in_ptr`, `out_ptr`, `outer`, `dim_size`,
/// `inner`, `total`, `op`. `outer` is declared but never loaded by the body.
///
/// Every thread whose global id `gtid` is below `total` starts at input
/// element `(gtid / inner) * dim_size * inner + gtid % inner` and visits
/// `dim_size` elements spaced `inner` apart; input offsets are scaled by 8
/// bytes. Values are compared as signed integers (`setp.gt.s64` /
/// `setp.lt.s64`). `op == 0` keeps the maximum, any other value the minimum.
/// The comparisons are strict, so the earliest position wins on ties. The
/// winning position is written as a signed 64-bit value at
/// `out_ptr + gtid * 8`.
///
/// All element indices are computed in 32-bit registers.
const ARGREDUCE_I64_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry argreduce_i64_kernel(
.param .u64 in_ptr, .param .u64 out_ptr,
.param .u32 outer, .param .u32 dim_size, .param .u32 inner,
.param .u32 total, .param .u32 op
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %dim, %inn, %op_r;
.reg .u32 %oidx, %iidx, %base, %j, %elem, %best_j;
.reg .u64 %in, %out, %off, %addr;
.reg .s64 %v, %acc, %best_s64;
.reg .pred %p, %is_max, %not_max, %better, %lt, %gt;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %dim, [dim_size];
ld.param.u32 %inn, [inner];
ld.param.u32 %op_r, [op];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot;
@%p bra DONE;
setp.eq.u32 %is_max, %op_r, 0;
not.pred %not_max, %is_max;
div.u32 %oidx, %gtid, %inn;
rem.u32 %iidx, %gtid, %inn;
mul.lo.u32 %base, %oidx, %dim;
mul.lo.u32 %base, %base, %inn;
add.u32 %base, %base, %iidx;
cvt.u64.u32 %off, %base;
shl.b64 %off, %off, 3;
add.u64 %addr, %in, %off;
ld.global.s64 %acc, [%addr];
mov.u32 %best_j, 0;
mov.u32 %j, 1;
LOOP:
setp.ge.u32 %p, %j, %dim;
@%p bra STORE;
mul.lo.u32 %elem, %j, %inn;
add.u32 %elem, %elem, %base;
cvt.u64.u32 %off, %elem;
shl.b64 %off, %off, 3;
add.u64 %addr, %in, %off;
ld.global.s64 %v, [%addr];
setp.gt.s64 %gt, %v, %acc;
setp.lt.s64 %lt, %v, %acc;
and.pred %gt, %gt, %is_max;
and.pred %lt, %lt, %not_max;
or.pred %better, %gt, %lt;
@%better mov.s64 %acc, %v;
@%better mov.u32 %best_j, %j;
add.u32 %j, %j, 1;
bra LOOP;
STORE:
cvt.s64.u32 %best_s64, %best_j;
cvt.u64.u32 %off, %gtid;
shl.b64 %off, %off, 3;
add.u64 %addr, %out, %off;
st.global.s64 [%addr], %best_s64;
DONE:
ret;
}
";
/// PTX source for `argreduce_f16_kernel`, the arg-reduction kernel for
/// half-precision input stored as 16-bit words.
///
/// This module declares `.target sm_53`, whereas the other kernels in this
/// file declare `sm_52`.
///
/// Kernel parameters, in order: `in_ptr`, `out_ptr`, `outer`, `dim_size`,
/// `inner`, `total`, `op`. `outer` is declared but never loaded by the body.
///
/// Every thread whose global id `gtid` is below `total` starts at input
/// element `(gtid / inner) * dim_size * inner + gtid % inner` and visits
/// `dim_size` elements spaced `inner` apart; input offsets are scaled by 2
/// bytes. Each element is loaded as a `b16` word and widened with
/// `cvt.f32.f16`, and the comparison runs in `f32`. `op == 0` keeps the
/// maximum, any other value the minimum. The comparisons are strict, so the
/// earliest position wins on ties. The winning position is written as a signed
/// 64-bit value at `out_ptr + gtid * 8`.
///
/// All element indices are computed in 32-bit registers.
const ARGREDUCE_F16_PTX: &str = "\
.version 7.0
.target sm_53
.address_size 64
.visible .entry argreduce_f16_kernel(
.param .u64 in_ptr, .param .u64 out_ptr,
.param .u32 outer, .param .u32 dim_size, .param .u32 inner,
.param .u32 total, .param .u32 op
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %dim, %inn, %op_r;
.reg .u32 %oidx, %iidx, %base, %j, %elem, %best_j;
.reg .u64 %in, %out, %off, %addr;
.reg .b16 %h;
.reg .f32 %v, %acc;
.reg .s64 %best_s64;
.reg .pred %p, %is_max, %not_max, %better, %lt, %gt;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %dim, [dim_size];
ld.param.u32 %inn, [inner];
ld.param.u32 %op_r, [op];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot;
@%p bra DONE;
setp.eq.u32 %is_max, %op_r, 0;
not.pred %not_max, %is_max;
div.u32 %oidx, %gtid, %inn;
rem.u32 %iidx, %gtid, %inn;
mul.lo.u32 %base, %oidx, %dim;
mul.lo.u32 %base, %base, %inn;
add.u32 %base, %base, %iidx;
cvt.u64.u32 %off, %base;
shl.b64 %off, %off, 1;
add.u64 %addr, %in, %off;
ld.global.b16 %h, [%addr];
cvt.f32.f16 %acc, %h;
mov.u32 %best_j, 0;
mov.u32 %j, 1;
LOOP:
setp.ge.u32 %p, %j, %dim;
@%p bra STORE;
mul.lo.u32 %elem, %j, %inn;
add.u32 %elem, %elem, %base;
cvt.u64.u32 %off, %elem;
shl.b64 %off, %off, 1;
add.u64 %addr, %in, %off;
ld.global.b16 %h, [%addr];
cvt.f32.f16 %v, %h;
setp.gt.f32 %gt, %v, %acc;
setp.lt.f32 %lt, %v, %acc;
and.pred %gt, %gt, %is_max;
and.pred %lt, %lt, %not_max;
or.pred %better, %gt, %lt;
@%better mov.f32 %acc, %v;
@%better mov.u32 %best_j, %j;
add.u32 %j, %j, 1;
bra LOOP;
STORE:
cvt.s64.u32 %best_s64, %best_j;
cvt.u64.u32 %off, %gtid;
shl.b64 %off, %off, 3;
add.u64 %addr, %out, %off;
st.global.s64 [%addr], %best_s64;
DONE:
ret;
}
";
/// PTX source for `argreduce_bf16_kernel`, the arg-reduction kernel for
/// bfloat16 input stored as 16-bit words.
///
/// Kernel parameters, in order: `in_ptr`, `out_ptr`, `outer`, `dim_size`,
/// `inner`, `total`, `op`. `outer` is declared but never loaded by the body.
///
/// Every thread whose global id `gtid` is below `total` starts at input
/// element `(gtid / inner) * dim_size * inner + gtid % inner` and visits
/// `dim_size` elements spaced `inner` apart; input offsets are scaled by 2
/// bytes. Each element is loaded as a `u16`, zero-extended to 32 bits, shifted
/// left by 16 and reinterpreted as `f32` with `mov.b32`, so no dedicated
/// conversion instruction is used. The comparison runs in `f32`. `op == 0`
/// keeps the maximum, any other value the minimum. The comparisons are strict,
/// so the earliest position wins on ties. The winning position is written as a
/// signed 64-bit value at `out_ptr + gtid * 8`.
///
/// All element indices are computed in 32-bit registers.
const ARGREDUCE_BF16_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry argreduce_bf16_kernel(
.param .u64 in_ptr, .param .u64 out_ptr,
.param .u32 outer, .param .u32 dim_size, .param .u32 inner,
.param .u32 total, .param .u32 op
) {
.reg .u32 %gtid, %bid, %bdim, %tot, %dim, %inn, %op_r;
.reg .u32 %oidx, %iidx, %base, %j, %elem, %best_j, %bits;
.reg .u16 %h;
.reg .u64 %in, %out, %off, %addr;
.reg .f32 %v, %acc;
.reg .s64 %best_s64;
.reg .pred %p, %is_max, %not_max, %better, %lt, %gt;
ld.param.u64 %in, [in_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %tot, [total];
ld.param.u32 %dim, [dim_size];
ld.param.u32 %inn, [inner];
ld.param.u32 %op_r, [op];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %gtid, %tid.x;
mad.lo.u32 %gtid, %bid, %bdim, %gtid;
setp.ge.u32 %p, %gtid, %tot;
@%p bra DONE;
setp.eq.u32 %is_max, %op_r, 0;
not.pred %not_max, %is_max;
div.u32 %oidx, %gtid, %inn;
rem.u32 %iidx, %gtid, %inn;
mul.lo.u32 %base, %oidx, %dim;
mul.lo.u32 %base, %base, %inn;
add.u32 %base, %base, %iidx;
cvt.u64.u32 %off, %base;
shl.b64 %off, %off, 1;
add.u64 %addr, %in, %off;
ld.global.u16 %h, [%addr];
cvt.u32.u16 %bits, %h;
shl.b32 %bits, %bits, 16;
mov.b32 %acc, %bits;
mov.u32 %best_j, 0;
mov.u32 %j, 1;
LOOP:
setp.ge.u32 %p, %j, %dim;
@%p bra STORE;
mul.lo.u32 %elem, %j, %inn;
add.u32 %elem, %elem, %base;
cvt.u64.u32 %off, %elem;
shl.b64 %off, %off, 1;
add.u64 %addr, %in, %off;
ld.global.u16 %h, [%addr];
cvt.u32.u16 %bits, %h;
shl.b32 %bits, %bits, 16;
mov.b32 %v, %bits;
setp.gt.f32 %gt, %v, %acc;
setp.lt.f32 %lt, %v, %acc;
and.pred %gt, %gt, %is_max;
and.pred %lt, %lt, %not_max;
or.pred %better, %gt, %lt;
@%better mov.f32 %acc, %v;
@%better mov.u32 %best_j, %j;
add.u32 %j, %j, 1;
bra LOOP;
STORE:
cvt.s64.u32 %best_s64, %best_j;
cvt.u64.u32 %off, %gtid;
shl.b64 %off, %off, 3;
add.u64 %addr, %out, %off;
st.global.s64 [%addr], %best_s64;
DONE:
ret;
}
";
/// Shared driver for every argmax/argmin entry point in this file.
///
/// Treats `in_slice` as a row-major `[outer, dim_size, inner]` array, reduces
/// over the middle axis with the kernel named `kernel_name` from `ptx`, and
/// returns `outer * inner` indices along that axis as `i64`.
///
/// `op` is forwarded to the kernel unchanged (`ARG_MAX` or `ARG_MIN`). The
/// caller must pair `ptx` with an element type `V` of the width that kernel
/// loads; nothing here can check that.
///
/// When `outer * inner == 0` or `dim_size == 0` no kernel is launched and a
/// zero-filled buffer of `outer * inner` elements is returned.
///
/// # Errors
///
/// * `GpuError::LengthMismatch` if the shape products overflow `usize`, if
///   `in_slice` holds fewer than `outer * dim_size * inner` elements, or if
///   that element count exceeds `u32::MAX` (the kernels index the input with
///   32-bit registers).
/// * `GpuError::PtxCompileFailed` if the kernel cannot be compiled or loaded.
/// * Any error propagated from device allocation or the launch itself.
// One parameter per piece of launch metadata; this is a private helper whose
// only callers are the thin wrappers below, so a parameter struct adds nothing.
#[allow(clippy::too_many_arguments)]
fn launch_argreduce<V: DeviceRepr + ValidAsZeroBits>(
    in_slice: &CudaSlice<V>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    device: &GpuDevice,
    ptx: &'static str,
    kernel_name: &'static str,
    op: u32,
) -> GpuResult<CudaSlice<i64>> {
    // Largest element count the kernels' 32-bit index registers can address.
    const INDEX_LIMIT: usize = u32::MAX as usize;

    let total = outer
        .checked_mul(inner)
        .ok_or(GpuError::LengthMismatch { a: outer, b: inner })?;
    let expect = outer
        .checked_mul(dim_size)
        .and_then(|x| x.checked_mul(inner))
        .ok_or(GpuError::LengthMismatch {
            a: outer,
            b: dim_size,
        })?;
    if in_slice.len() < expect {
        return Err(GpuError::LengthMismatch {
            a: in_slice.len(),
            b: expect,
        });
    }

    let stream = device.stream();
    if total == 0 || dim_size == 0 {
        // Nothing to reduce: hand back an all-zero index buffer.
        return Ok(stream.alloc_zeros::<i64>(total)?);
    }

    // The kernels take every extent as `u32` and compute element indices with
    // 32-bit multiplies. Reject anything larger instead of letting `as u32`
    // truncate silently and the kernel read the wrong addresses.
    if expect > INDEX_LIMIT {
        return Err(GpuError::LengthMismatch {
            a: expect,
            b: INDEX_LIMIT,
        });
    }

    let ctx = device.context();
    let f = get_or_compile(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: kernel_name,
            source: e,
        }
    })?;
    let mut out = stream.alloc_zeros::<i64>(total)?;
    let cfg = launch_1d(total);

    // Lossless: all three extents are non-zero here, so each of them (and
    // `total`) is bounded by `expect`, which was just checked against `u32::MAX`.
    let (outer_u, dim_u, inner_u, total_u) =
        (outer as u32, dim_size as u32, inner as u32, total as u32);

    // SAFETY: the arguments are pushed in the order the PTX entry points
    // declare their parameters (two 64-bit pointers, then five `u32`s).
    // `in_slice` holds at least `expect` elements and `out` holds `total`
    // `i64`s, which bounds every index the kernels compute. This relies on the
    // caller pairing `ptx` with a matching element type `V`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(in_slice)
            .arg(&mut out)
            .arg(&outer_u)
            .arg(&dim_u)
            .arg(&inner_u)
            .arg(&total_u)
            .arg(&op)
            .launch(cfg)?;
    }
    Ok(out)
}
// Generates one public arg-reduction entry point.
//
// Arguments: the function name, the element type of the input buffer, the PTX
// source constant, the kernel entry name inside that PTX, and the reduction
// selector. The generated function forwards straight to `launch_argreduce`.
macro_rules! arg_entry {
    ($fn_name:ident, $elem:ty, $ptx_src:ident, $kernel:literal, $mode:expr) => {
        #[doc = concat!("`", stringify!($fn_name), "` over a ", stringify!($elem), " value buffer.")]
        pub fn $fn_name(
            input: &CudaSlice<$elem>,
            outer: usize,
            dim_size: usize,
            inner: usize,
            d: &GpuDevice,
        ) -> GpuResult<CudaSlice<i64>> {
            launch_argreduce(input, outer, dim_size, inner, d, $ptx_src, $kernel, $mode)
        }
    };
}
// Entry points generated by `arg_entry!`. Each element type gets an argmax and
// an argmin function that share one PTX source and differ only in the
// selector passed as the last macro argument.

// f32
arg_entry!(
gpu_argmax_f32,
f32,
ARGREDUCE_F32_PTX,
"argreduce_f32_kernel",
ARG_MAX
);
arg_entry!(
gpu_argmin_f32,
f32,
ARGREDUCE_F32_PTX,
"argreduce_f32_kernel",
ARG_MIN
);
// f64
arg_entry!(
gpu_argmax_f64,
f64,
ARGREDUCE_F64_PTX,
"argreduce_f64_kernel",
ARG_MAX
);
arg_entry!(
gpu_argmin_f64,
f64,
ARGREDUCE_F64_PTX,
"argreduce_f64_kernel",
ARG_MIN
);
// i32
arg_entry!(
gpu_argmax_i32,
i32,
ARGREDUCE_I32_PTX,
"argreduce_i32_kernel",
ARG_MAX
);
arg_entry!(
gpu_argmin_i32,
i32,
ARGREDUCE_I32_PTX,
"argreduce_i32_kernel",
ARG_MIN
);
// i64
arg_entry!(
gpu_argmax_i64,
i64,
ARGREDUCE_I64_PTX,
"argreduce_i64_kernel",
ARG_MAX
);
arg_entry!(
gpu_argmin_i64,
i64,
ARGREDUCE_I64_PTX,
"argreduce_i64_kernel",
ARG_MIN
);
/// Argmax over half-precision values supplied as raw 16-bit words.
///
/// `input` is read as a row-major `[outer, dim_size, inner]` array and the
/// result holds `outer * inner` indices along the middle axis. The kernel
/// widens each word to `f32` before comparing, and the first maximum wins on
/// ties.
pub fn gpu_argmax_f16(
    input: &CudaSlice<u16>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "argreduce_f16_kernel";
    launch_argreduce(input, outer, dim_size, inner, d, ARGREDUCE_F16_PTX, KERNEL, ARG_MAX)
}
/// Argmin over half-precision values supplied as raw 16-bit words.
///
/// `input` is read as a row-major `[outer, dim_size, inner]` array and the
/// result holds `outer * inner` indices along the middle axis. The kernel
/// widens each word to `f32` before comparing, and the first minimum wins on
/// ties.
pub fn gpu_argmin_f16(
    input: &CudaSlice<u16>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "argreduce_f16_kernel";
    launch_argreduce(input, outer, dim_size, inner, d, ARGREDUCE_F16_PTX, KERNEL, ARG_MIN)
}
/// Argmax over bfloat16 values supplied as raw 16-bit words.
///
/// `input` is read as a row-major `[outer, dim_size, inner]` array and the
/// result holds `outer * inner` indices along the middle axis. The kernel
/// places each word in the upper half of an `f32` before comparing, and the
/// first maximum wins on ties.
pub fn gpu_argmax_bf16(
    input: &CudaSlice<u16>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "argreduce_bf16_kernel";
    launch_argreduce(input, outer, dim_size, inner, d, ARGREDUCE_BF16_PTX, KERNEL, ARG_MAX)
}
/// Argmin over bfloat16 values supplied as raw 16-bit words.
///
/// `input` is read as a row-major `[outer, dim_size, inner]` array and the
/// result holds `outer * inner` indices along the middle axis. The kernel
/// places each word in the upper half of an `f32` before comparing, and the
/// first minimum wins on ties.
pub fn gpu_argmin_bf16(
    input: &CudaSlice<u16>,
    outer: usize,
    dim_size: usize,
    inner: usize,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<i64>> {
    const KERNEL: &str = "argreduce_bf16_kernel";
    launch_argreduce(input, outer, dim_size, inner, d, ARGREDUCE_BF16_PTX, KERNEL, ARG_MIN)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens CUDA device 0; every test in this module needs a real GPU.
    fn dev() -> GpuDevice {
        GpuDevice::new(0).expect("cuda device")
    }

    /// Copies an index buffer back to the host, panicking on a driver error.
    fn fetch(gpu: &GpuDevice, indices: &CudaSlice<i64>) -> Vec<i64> {
        gpu.stream().clone_dtoh(indices).unwrap()
    }

    /// Whole-buffer reduction: a single row of seven values.
    #[test]
    fn argmax_argmin_f32_global() {
        let gpu = dev();
        let values = gpu
            .stream()
            .clone_htod(&vec![3.0f32, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0])
            .unwrap();
        let largest = gpu_argmax_f32(&values, 1, 7, 1, &gpu).unwrap();
        let smallest = gpu_argmin_f32(&values, 1, 7, 1, &gpu).unwrap();
        assert_eq!(fetch(&gpu, &largest), [5i64]);
        assert_eq!(fetch(&gpu, &smallest), [1i64]);
    }

    /// Equal maxima must resolve to the earliest position.
    #[test]
    fn argmax_f32_tie_first_index() {
        let gpu = dev();
        let values = gpu
            .stream()
            .clone_htod(&vec![5.0f32, 1.0, 2.0, 5.0])
            .unwrap();
        let largest = gpu_argmax_f32(&values, 1, 4, 1, &gpu).unwrap();
        assert_eq!(fetch(&gpu, &largest), [0i64]);
    }

    /// Two rows of three, reduced along the last axis.
    #[test]
    fn argmax_f32_along_dim() {
        let gpu = dev();
        let values = gpu
            .stream()
            .clone_htod(&vec![1.0f32, 9.0, 2.0, 7.0, 3.0, 4.0])
            .unwrap();
        let largest = gpu_argmax_f32(&values, 2, 3, 1, &gpu).unwrap();
        assert_eq!(fetch(&gpu, &largest), [1i64, 0i64]);
    }

    /// Same data viewed as 2x3 and reduced along the first axis.
    #[test]
    fn argmax_along_dim0_inner() {
        let gpu = dev();
        let values = gpu
            .stream()
            .clone_htod(&vec![1.0f32, 9.0, 2.0, 7.0, 3.0, 4.0])
            .unwrap();
        let largest = gpu_argmax_f32(&values, 1, 2, 3, &gpu).unwrap();
        assert_eq!(fetch(&gpu, &largest), [1i64, 0i64, 1i64]);
    }

    /// Integer kernels, including first-occurrence behaviour on ties.
    #[test]
    fn argmax_i32_and_i64() {
        let gpu = dev();

        let narrow = gpu.stream().clone_htod(&vec![-3i32, 7, 7, 2]).unwrap();
        let largest = gpu_argmax_i32(&narrow, 1, 4, 1, &gpu).unwrap();
        assert_eq!(fetch(&gpu, &largest), [1i64]);

        let wide = gpu
            .stream()
            .clone_htod(&vec![10i64, -5, 100, 100])
            .unwrap();
        let smallest = gpu_argmin_i64(&wide, 1, 4, 1, &gpu).unwrap();
        assert_eq!(fetch(&gpu, &smallest), [1i64]);
    }

    /// 16-bit float kernels fed with raw bit patterns from the `half` crate.
    #[test]
    fn argmax_f16_bf16() {
        let gpu = dev();

        let half_bits: Vec<u16> = [1.0f32, 5.0, 2.0]
            .iter()
            .map(|&x| half::f16::from_f32(x).to_bits())
            .collect();
        let half_dev = gpu.stream().clone_htod(&half_bits).unwrap();
        let half_best = gpu_argmax_f16(&half_dev, 1, 3, 1, &gpu).unwrap();
        assert_eq!(fetch(&gpu, &half_best), [1i64]);

        let brain_bits: Vec<u16> = [1.0f32, 2.0, 8.0]
            .iter()
            .map(|&x| half::bf16::from_f32(x).to_bits())
            .collect();
        let brain_dev = gpu.stream().clone_htod(&brain_bits).unwrap();
        let brain_best = gpu_argmax_bf16(&brain_dev, 1, 3, 1, &gpu).unwrap();
        assert_eq!(fetch(&gpu, &brain_best), [2i64]);
    }
}