#![cfg(feature = "cuda")]
use cudarc::driver::{CudaSlice, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits};
use crate::device::GpuDevice;
use crate::error::{GpuError, GpuResult};
use crate::module_cache::{get_or_compile, get_or_compile_owned};
/// Number of threads per block (x dimension) used by `launch_1d`.
const BLOCK_SIZE: u32 = 256;
/// Builds a 1-D launch configuration with one thread per element.
///
/// The grid holds `ceil(n / BLOCK_SIZE)` blocks of `BLOCK_SIZE` threads, with a
/// minimum of one block so that `n == 0` still yields a non-empty grid.
///
/// The block count is computed in `usize` with `div_ceil`. The previous
/// `(n as u32).saturating_add(BLOCK_SIZE - 1) / BLOCK_SIZE` form truncated
/// `n` above `u32::MAX` and rounded *down* when `n` was within
/// `BLOCK_SIZE - 1` of `u32::MAX`, leaving the tail elements without a thread.
/// Counts that do not fit in `u32` saturate to `u32::MAX`.
fn launch_1d(n: usize) -> LaunchConfig {
    let blocks = n.div_ceil(BLOCK_SIZE as usize).max(1);
    // Saturate rather than wrap: a wrapped grid would silently cover too few elements.
    let grid = blocks.min(u32::MAX as usize) as u32;
    LaunchConfig {
        grid_dim: (grid, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    }
}
/// Generates the PTX source of an elementwise comparison kernel named `kernel_name`.
///
/// The generated entry point takes `(a_ptr, b_ptr, out_ptr, n)`. Each thread
/// with global index `idx < n` loads `a[idx]` and `b[idx]` into `%va` / `%vb`,
/// runs `setp` to produce the predicate `%c`, and stores a single byte
/// (1 when `%c` is set, otherwise 0) to `out[idx]`.
///
/// * `in_shift` - left shift applied to `idx` to obtain the input byte offset
///   (i.e. log2 of the input element size in bytes).
/// * `load_ty` - PTX type suffix used for both `ld.global` loads.
/// * `reg_decl` - PTX declaration of the operand registers `%va` and `%vb`.
/// * `setp` - PTX instruction that sets predicate `%c` from `%va` and `%vb`.
fn cmp_ptx(
kernel_name: &str,
in_shift: u32, load_ty: &str, reg_decl: &str, setp: &str, ) -> String {
// `{{` / `}}` in the template are `format!` escapes for the literal PTX braces.
format!(
"\
.version 7.0
.target sm_52
.address_size 64
.visible .entry {kernel_name}(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {{
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %ioff, %ooff;
{reg_decl}
.reg .u16 %res;
.reg .pred %p, %c;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %ioff, %idx;
shl.b64 %ioff, %ioff, {in_shift};
add.u64 %a, %a, %ioff;
add.u64 %b, %b, %ioff;
// output is 1 byte per element: out_off = idx
cvt.u64.u32 %ooff, %idx;
add.u64 %out, %out, %ooff;
ld.global.{load_ty} %va, [%a];
ld.global.{load_ty} %vb, [%b];
{setp}
selp.u16 %res, 1, 0, %c;
st.global.u8 [%out], %res;
DONE:
ret;
}}
"
)
}
/// Generates the PTX source of an elementwise comparison kernel for 16-bit inputs.
///
/// The generated entry point takes `(a_ptr, b_ptr, out_ptr, n)`. Each thread
/// with global index `idx < n` loads the raw 16-bit values `a[idx]` / `b[idx]`
/// into `%ha` / `%hb` (input byte offset is `idx << 1`), runs `decode` and then
/// `setp`, and stores a single byte (1 when predicate `%c` is set, otherwise 0)
/// to `out[idx]`. `%zero16` is set to 0 before `decode` runs.
///
/// * `target` - value emitted in the `.target` directive.
/// * `decode` - PTX that converts `%ha` / `%hb` into the `f32` registers
///   `%fa` / `%fb` (scratch registers `%ua`, `%ub`, `%zero16` are available).
/// * `setp` - PTX instruction that sets predicate `%c`.
fn cmp_half_ptx(
kernel_name: &str,
target: &str, decode: &str, setp: &str, ) -> String {
// `{{` / `}}` in the template are `format!` escapes for the literal PTX braces.
format!(
"\
.version 7.0
.target {target}
.address_size 64
.visible .entry {kernel_name}(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {{
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %ioff, %ooff;
.reg .b16 %ha, %hb, %zero16;
.reg .b32 %ua, %ub;
.reg .u16 %res;
.reg .f32 %fa, %fb;
.reg .pred %p, %c;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %ioff, %idx;
shl.b64 %ioff, %ioff, 1;
add.u64 %a, %a, %ioff;
add.u64 %b, %b, %ioff;
cvt.u64.u32 %ooff, %idx;
add.u64 %out, %out, %ooff;
mov.b16 %zero16, 0;
ld.global.b16 %ha, [%a];
ld.global.b16 %hb, [%b];
{decode}
{setp}
selp.u16 %res, 1, 0, %c;
st.global.u8 [%out], %res;
DONE:
ret;
}}
"
)
}
/// `decode` snippet for `cmp_half_ptx` used by the bf16 comparison kernels.
///
/// Packs `{%zero16, %ha}` (resp. `%hb`) into the 32-bit register `%ua`
/// (resp. `%ub`) and moves those bits into the `f32` register `%fa`
/// (resp. `%fb`). Per the PTX `mov` pack semantics the first vector element
/// becomes the low-order half, so the 16 input bits end up in the upper half
/// of the 32-bit value.
const BF16_DECODE: &str = "\
mov.b32 %ua, {%zero16, %ha}; mov.b32 %fa, %ua;
mov.b32 %ub, {%zero16, %hb}; mov.b32 %fb, %ub;";
/// `decode` snippet for `cmp_half_ptx` used by the f16 comparison kernels:
/// converts `%ha` / `%hb` to the `f32` registers `%fa` / `%fb` with `cvt.f32.f16`.
const F16_DECODE: &str = "\
cvt.f32.f16 %fa, %ha;
cvt.f32.f16 %fb, %hb;";
/// Generates the PTX source of an elementwise binary logical kernel on byte buffers.
///
/// The generated entry point takes `(a_ptr, b_ptr, out_ptr, n)`. Each thread
/// with global index `idx < n` loads the bytes `a[idx]` and `b[idx]`, turns
/// each into a predicate (`!= 0`), combines them with `{op}.pred`, and stores
/// 1 or 0 as a single byte to `out[idx]`.
///
/// * `kernel_name` - name of the generated entry point.
/// * `op` - PTX mnemonic placed in front of `.pred` (the callers in this file
///   pass `and`, `or` and `xor`).
fn logic_bin_ptx(kernel_name: &str, op: &str ) -> String {
// `{{` / `}}` in the template are `format!` escapes for the literal PTX braces.
format!(
"\
.version 7.0
.target sm_52
.address_size 64
.visible .entry {kernel_name}(
.param .u64 a_ptr, .param .u64 b_ptr, .param .u64 out_ptr, .param .u32 n
) {{
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %b, %out, %off;
.reg .u16 %va, %vb, %res;
.reg .pred %pa, %pb, %pr, %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %b, [b_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
add.u64 %a, %a, %off;
add.u64 %b, %b, %off;
add.u64 %out, %out, %off;
ld.global.u8 %va, [%a];
ld.global.u8 %vb, [%b];
setp.ne.u16 %pa, %va, 0;
setp.ne.u16 %pb, %vb, 0;
{op}.pred %pr, %pa, %pb;
selp.u16 %res, 1, 0, %pr;
st.global.u8 [%out], %res;
DONE:
ret;
}}
"
)
}
/// PTX source of `not_bool_kernel(a_ptr, out_ptr, n)`.
///
/// Each thread with global index `idx < n` loads the byte `a[idx]` and stores
/// 1 to `out[idx]` when that byte is 0, otherwise 0.
const NOT_BOOL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry not_bool_kernel(
.param .u64 a_ptr, .param .u64 out_ptr, .param .u32 n
) {
.reg .u32 %idx, %bid, %bdim, %nr;
.reg .u64 %a, %out, %off;
.reg .u16 %va, %res;
.reg .pred %pa, %p;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ge.u32 %p, %idx, %nr;
@%p bra DONE;
cvt.u64.u32 %off, %idx;
add.u64 %a, %a, %off;
add.u64 %out, %out, %off;
ld.global.u8 %va, [%a];
// res = (va == 0) ? 1 : 0
setp.eq.u16 %pa, %va, 0;
selp.u16 %res, 1, 0, %pa;
st.global.u8 [%out], %res;
DONE:
ret;
}
";
/// PTX source of `reduce_bool_kernel(a_ptr, out_ptr, n, op)`.
///
/// Only the thread whose global index is 0 does any work; every other thread
/// returns immediately. That thread seeds an accumulator from `a[0]`
/// normalised to 0/1 (so the caller must guarantee `n >= 1`), then walks
/// `a[1..n]` sequentially, combining each normalised element with OR when
/// `op == 0` and with AND for any other `op`. The final 0/1 accumulator is
/// stored as one byte to `out[0]`.
const REDUCE_BOOL_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry reduce_bool_kernel(
.param .u64 a_ptr, .param .u64 out_ptr, .param .u32 n, .param .u32 op
) {
.reg .u32 %idx, %bid, %bdim, %nr, %op_r, %i;
.reg .u64 %a, %out, %off, %cur;
.reg .u16 %acc, %v, %vn;
.reg .pred %only0, %p, %is_any, %pacc, %pv;
ld.param.u64 %a, [a_ptr];
ld.param.u64 %out, [out_ptr];
ld.param.u32 %nr, [n];
ld.param.u32 %op_r, [op];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %idx, %tid.x;
mad.lo.u32 %idx, %bid, %bdim, %idx;
setp.ne.u32 %only0, %idx, 0;
@%only0 bra DONE;
setp.eq.u32 %is_any, %op_r, 0;
// Initialise accumulator from a[0] normalised to 0/1 (n >= 1 guaranteed).
ld.global.u8 %v, [%a];
setp.ne.u16 %pv, %v, 0;
selp.u16 %acc, 1, 0, %pv;
mov.u32 %i, 1;
LOOP:
setp.ge.u32 %p, %i, %nr;
@%p bra STORE;
cvt.u64.u32 %off, %i;
add.u64 %cur, %a, %off;
ld.global.u8 %v, [%cur];
setp.ne.u16 %pv, %v, 0;
selp.u16 %vn, 1, 0, %pv;
setp.ne.u16 %pacc, %acc, 0;
setp.ne.u16 %pv, %vn, 0;
// any: acc = acc OR v ; all: acc = acc AND v
@%is_any or.pred %pacc, %pacc, %pv;
@!%is_any and.pred %pacc, %pacc, %pv;
selp.u16 %acc, 1, 0, %pacc;
add.u32 %i, %i, 1;
bra LOOP;
STORE:
st.global.u8 [%out], %acc;
DONE:
ret;
}
";
/// `op` selector for `reduce_bool_kernel`: the kernel combines elements with OR when `op == 0`.
const REDUCE_ANY: u32 = 0;
/// `op` selector for `reduce_bool_kernel`: any non-zero `op` makes the kernel combine with AND.
const REDUCE_ALL: u32 = 1;
/// Launches a two-input, one-thread-per-element kernel that writes one `u8` per element.
///
/// Obtains the entry point `kernel_name` for `ptx` through
/// `get_or_compile_owned`, allocates an `n`-byte output with `alloc_zeros`,
/// and launches with the arguments `(a, b, out, n as u32)` using the
/// configuration from `launch_1d(n)`. For `n == 0` an empty buffer is
/// returned without compiling or launching anything.
///
/// # Errors
/// * `GpuError::LengthMismatch` when `a` or `b` holds fewer than `n` elements.
/// * `GpuError::ShapeMismatch` (tagged with `err_label`) when `n` exceeds
///   `u32::MAX`, because the kernel receives the element count as a `u32`.
/// * `GpuError::PtxCompileFailed` (tagged with `err_label`) when
///   `get_or_compile_owned` fails.
/// * Any error propagated by `?` from allocation or launch.
fn launch_cmp<T: DeviceRepr + ValidAsZeroBits>(
    a: &CudaSlice<T>,
    b: &CudaSlice<T>,
    n: usize,
    device: &GpuDevice,
    ptx: String,
    kernel_name: String,
    err_label: &'static str,
) -> GpuResult<CudaSlice<u8>> {
    if a.len() < n || b.len() < n {
        return Err(GpuError::LengthMismatch {
            a: a.len().min(b.len()),
            b: n,
        });
    }
    // The kernel takes `n` as `.u32`. A bare `n as u32` would truncate and the
    // launch would silently leave part of the output unwritten, so reject instead.
    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: err_label,
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    let n_u32 = n as u32;

    let stream = device.stream();
    if n == 0 {
        return Ok(stream.alloc_zeros::<u8>(0)?);
    }

    let ctx = device.context();
    let f = get_or_compile_owned(ctx, ptx, kernel_name, device.ordinal() as u32).map_err(|e| {
        GpuError::PtxCompileFailed {
            kernel: err_label,
            source: e,
        }
    })?;

    let mut out = stream.alloc_zeros::<u8>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: `a` and `b` hold at least `n` elements (checked above) and `out`
    // holds exactly `n` bytes. The caller-supplied PTX must declare the entry
    // as `(u64 a_ptr, u64 b_ptr, u64 out_ptr, u32 n)` and touch only indices
    // below `n`, as the generators in this file do.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a)
            .arg(b)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Launches a comparison kernel over raw 16-bit (`u16`) inputs.
///
/// Forwards everything to `launch_cmp`, tagging errors with the label `"cmp_half"`.
fn launch_cmp_half(
    a: &CudaSlice<u16>,
    b: &CudaSlice<u16>,
    n: usize,
    device: &GpuDevice,
    ptx: String,
    kernel_name: String,
) -> GpuResult<CudaSlice<u8>> {
    const ERR_LABEL: &str = "cmp_half";
    // The element type is inferred as `u16` from the slice arguments.
    launch_cmp(a, b, n, device, ptx, kernel_name, ERR_LABEL)
}
/// Launches a binary logical kernel over two byte buffers of equal length.
///
/// Returns `GpuError::LengthMismatch` when the lengths differ; otherwise
/// delegates to `launch_cmp` over the full length, using `kernel_name` both as
/// the entry-point name and as the error label.
fn launch_logic_bin(
    a: &CudaSlice<u8>,
    b: &CudaSlice<u8>,
    device: &GpuDevice,
    ptx: String,
    kernel_name: &'static str,
) -> GpuResult<CudaSlice<u8>> {
    let (len_a, len_b) = (a.len(), b.len());
    if len_a == len_b {
        let entry = String::from(kernel_name);
        launch_cmp(a, b, len_a, device, ptx, entry, kernel_name)
    } else {
        Err(GpuError::LengthMismatch { a: len_a, b: len_b })
    }
}
/// Launches `not_bool_kernel` over every byte of `a` and returns the result buffer.
///
/// An empty input yields an empty buffer without compiling or launching.
///
/// # Errors
/// * `GpuError::ShapeMismatch` when `a.len()` exceeds `u32::MAX`, because the
///   kernel receives the element count as a `u32`.
/// * `GpuError::PtxCompileFailed` when `get_or_compile` fails.
/// * Any error propagated by `?` from allocation or launch.
fn launch_not(a: &CudaSlice<u8>, device: &GpuDevice) -> GpuResult<CudaSlice<u8>> {
    let n = a.len();
    // A bare `n as u32` would truncate and leave part of the output unwritten.
    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "not_bool_kernel",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    let n_u32 = n as u32;

    let stream = device.stream();
    if n == 0 {
        return Ok(stream.alloc_zeros::<u8>(0)?);
    }

    let ctx = device.context();
    let f = get_or_compile(
        ctx,
        NOT_BOOL_PTX,
        "not_bool_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "not_bool_kernel",
        source: e,
    })?;

    let mut out = stream.alloc_zeros::<u8>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the argument list matches the kernel's `(a_ptr, out_ptr, n)`
    // parameters; `a` and `out` both hold `n` bytes and the kernel skips
    // every thread whose index is `>= n`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a)
            .arg(&mut out)
            .arg(&n_u32)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Reduces a byte buffer to a single 0/1 byte with `reduce_bool_kernel`.
///
/// The kernel is launched with one block of one thread and walks the input
/// sequentially. `op` selects the combination (0 is OR, anything else is
/// AND — see `REDUCE_ANY` / `REDUCE_ALL`). For an empty input no kernel runs
/// and a one-element buffer holding `empty_identity` is returned.
///
/// # Errors
/// * `GpuError::ShapeMismatch` when `a.len()` exceeds `u32::MAX`; the kernel
///   takes the element count as a `u32`, and a truncated count would make the
///   reduction ignore part of the input.
/// * `GpuError::PtxCompileFailed` when `get_or_compile` fails.
/// * Any error propagated by `?` from allocation, copy or launch.
fn launch_reduce_bool(
    a: &CudaSlice<u8>,
    device: &GpuDevice,
    op: u32,
    empty_identity: u8,
) -> GpuResult<CudaSlice<u8>> {
    let n = a.len();
    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "reduce_bool_kernel",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    let n_u32 = n as u32;

    let stream = device.stream();
    if n == 0 {
        // The kernel reads `a[0]` unconditionally, so the empty case is answered on the host.
        let host = [empty_identity];
        return Ok(stream.clone_htod(&host)?);
    }

    let ctx = device.context();
    let f = get_or_compile(
        ctx,
        REDUCE_BOOL_PTX,
        "reduce_bool_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "reduce_bool_kernel",
        source: e,
    })?;

    let mut out = stream.alloc_zeros::<u8>(1)?;
    // Only the thread with global index 0 does any work in this kernel.
    let cfg = LaunchConfig {
        grid_dim: (1, 1, 1),
        block_dim: (1, 1, 1),
        shared_mem_bytes: 0,
    };
    // SAFETY: the argument list matches the kernel's `(a_ptr, out_ptr, n, op)`
    // parameters; `a` holds `n >= 1` bytes, `out` holds one byte, and the
    // kernel reads only `a[0..n]` and writes only `out[0]`.
    unsafe {
        stream
            .launch_builder(&f)
            .arg(a)
            .arg(&mut out)
            .arg(&n_u32)
            .arg(&op)
            .launch(cfg)?;
    }
    Ok(out)
}
/// Builds the PTX `setp` instruction that compares `%va` with `%vb` using
/// operator `op` on type `ty` and writes the result to predicate `%c`.
fn setp_for(op: &str, ty: &str) -> String {
    ["setp.", op, ".", ty, " %c, %va, %vb;"].concat()
}
/// Elementwise comparison of the first `n` elements of two `f32` buffers.
///
/// `op` is spliced verbatim into the kernel name (`cmp_{op}_f32_kernel`) and
/// into the PTX `setp.{op}.f32` instruction. The result holds one byte per
/// element: 1 where the comparison holds, otherwise 0.
///
/// NOTE(review): per the PTX ISA the plain float operators (`eq`, `ne`, `lt`, ...)
/// are ordered comparisons; callers wanting NaN-aware `!=` presumably need
/// `neu` — confirm against the call sites.
///
/// # Errors
/// Propagates every error returned by `launch_cmp` (labelled `"cmp_f32"`).
pub fn gpu_cmp_f32(
    a: &CudaSlice<f32>,
    b: &CudaSlice<f32>,
    n: usize,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let kernel = ["cmp_", op, "_f32_kernel"].concat();
    let predicate = setp_for(op, "f32");
    // Shift of 2: each input element is 4 bytes wide.
    let source = cmp_ptx(&kernel, 2, "f32", ".reg .f32 %va, %vb;", &predicate);
    launch_cmp(a, b, n, d, source, kernel, "cmp_f32")
}
/// Elementwise comparison of the first `n` elements of two `f64` buffers.
///
/// `op` is spliced verbatim into the kernel name (`cmp_{op}_f64_kernel`) and
/// into the PTX `setp.{op}.f64` instruction. The result holds one byte per
/// element: 1 where the comparison holds, otherwise 0.
///
/// # Errors
/// Propagates every error returned by `launch_cmp` (labelled `"cmp_f64"`).
pub fn gpu_cmp_f64(
    a: &CudaSlice<f64>,
    b: &CudaSlice<f64>,
    n: usize,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let kernel = ["cmp_", op, "_f64_kernel"].concat();
    let predicate = setp_for(op, "f64");
    // Shift of 3: each input element is 8 bytes wide.
    let source = cmp_ptx(&kernel, 3, "f64", ".reg .f64 %va, %vb;", &predicate);
    launch_cmp(a, b, n, d, source, kernel, "cmp_f64")
}
/// Elementwise comparison of the first `n` elements of two `i32` buffers.
///
/// `op` is spliced verbatim into the kernel name (`cmp_{op}_i32_kernel`) and
/// into the PTX `setp.{op}.s32` instruction. The result holds one byte per
/// element: 1 where the comparison holds, otherwise 0.
///
/// # Errors
/// Propagates every error returned by `launch_cmp` (labelled `"cmp_i32"`).
pub fn gpu_cmp_i32(
    a: &CudaSlice<i32>,
    b: &CudaSlice<i32>,
    n: usize,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let kernel = ["cmp_", op, "_i32_kernel"].concat();
    let predicate = setp_for(op, "s32");
    // Shift of 2: each input element is 4 bytes wide.
    let source = cmp_ptx(&kernel, 2, "s32", ".reg .s32 %va, %vb;", &predicate);
    launch_cmp(a, b, n, d, source, kernel, "cmp_i32")
}
/// Elementwise comparison of the first `n` elements of two `i64` buffers.
///
/// `op` is spliced verbatim into the kernel name (`cmp_{op}_i64_kernel`) and
/// into the PTX `setp.{op}.s64` instruction. The result holds one byte per
/// element: 1 where the comparison holds, otherwise 0.
///
/// # Errors
/// Propagates every error returned by `launch_cmp` (labelled `"cmp_i64"`).
pub fn gpu_cmp_i64(
    a: &CudaSlice<i64>,
    b: &CudaSlice<i64>,
    n: usize,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let kernel = ["cmp_", op, "_i64_kernel"].concat();
    let predicate = setp_for(op, "s64");
    // Shift of 3: each input element is 8 bytes wide.
    let source = cmp_ptx(&kernel, 3, "s64", ".reg .s64 %va, %vb;", &predicate);
    launch_cmp(a, b, n, d, source, kernel, "cmp_i64")
}
/// Elementwise comparison of the first `n` elements of two bf16 buffers stored as raw `u16`.
///
/// Each value is widened to `f32` with `BF16_DECODE` and compared with
/// `setp.{op}.f32`; `op` is also spliced into the kernel name
/// (`cmp_{op}_bf16_kernel`). The result holds one byte per element: 1 where
/// the comparison holds, otherwise 0.
///
/// # Errors
/// Propagates every error returned by `launch_cmp_half`.
pub fn gpu_cmp_bf16(
    a: &CudaSlice<u16>,
    b: &CudaSlice<u16>,
    n: usize,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let kernel = ["cmp_", op, "_bf16_kernel"].concat();
    let predicate = ["setp.", op, ".f32 %c, %fa, %fb;"].concat();
    let source = cmp_half_ptx(&kernel, "sm_52", BF16_DECODE, &predicate);
    launch_cmp_half(a, b, n, d, source, kernel)
}
/// Elementwise comparison of the first `n` elements of two f16 buffers stored as raw `u16`.
///
/// Each value is widened to `f32` with `F16_DECODE` and compared with
/// `setp.{op}.f32`; `op` is also spliced into the kernel name
/// (`cmp_{op}_f16_kernel`). The PTX is generated for target `sm_53`. The
/// result holds one byte per element: 1 where the comparison holds, otherwise 0.
///
/// # Errors
/// Propagates every error returned by `launch_cmp_half`.
pub fn gpu_cmp_f16(
    a: &CudaSlice<u16>,
    b: &CudaSlice<u16>,
    n: usize,
    op: &str,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    let kernel = ["cmp_", op, "_f16_kernel"].concat();
    let predicate = ["setp.", op, ".f32 %c, %fa, %fb;"].concat();
    let source = cmp_half_ptx(&kernel, "sm_53", F16_DECODE, &predicate);
    launch_cmp_half(a, b, n, d, source, kernel)
}
/// Elementwise logical AND of two byte buffers of equal length.
///
/// Inputs are treated as true when non-zero; each output byte is 1 or 0.
///
/// # Errors
/// Propagates every error returned by `launch_logic_bin`, including
/// `GpuError::LengthMismatch` when the lengths differ.
pub fn gpu_and_bool(
    a: &CudaSlice<u8>,
    b: &CudaSlice<u8>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "and_bool_kernel";
    launch_logic_bin(a, b, d, logic_bin_ptx(KERNEL, "and"), KERNEL)
}
/// Elementwise logical OR of two byte buffers of equal length.
///
/// Inputs are treated as true when non-zero; each output byte is 1 or 0.
///
/// # Errors
/// Propagates every error returned by `launch_logic_bin`, including
/// `GpuError::LengthMismatch` when the lengths differ.
pub fn gpu_or_bool(
    a: &CudaSlice<u8>,
    b: &CudaSlice<u8>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "or_bool_kernel";
    launch_logic_bin(a, b, d, logic_bin_ptx(KERNEL, "or"), KERNEL)
}
/// Elementwise logical XOR of two byte buffers of equal length.
///
/// Inputs are treated as true when non-zero; each output byte is 1 or 0.
///
/// # Errors
/// Propagates every error returned by `launch_logic_bin`, including
/// `GpuError::LengthMismatch` when the lengths differ.
pub fn gpu_xor_bool(
    a: &CudaSlice<u8>,
    b: &CudaSlice<u8>,
    d: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    const KERNEL: &str = "xor_bool_kernel";
    launch_logic_bin(a, b, d, logic_bin_ptx(KERNEL, "xor"), KERNEL)
}
/// Elementwise logical NOT of a byte buffer: each output byte is 1 where the
/// input byte is 0, otherwise 0.
///
/// # Errors
/// Propagates every error returned by `launch_not`.
pub fn gpu_not_bool(a: &CudaSlice<u8>, d: &GpuDevice) -> GpuResult<CudaSlice<u8>> {
launch_not(a, d)
}
/// Returns a one-element device buffer holding 1 when any byte of `a` is
/// non-zero, otherwise 0. An empty `a` yields 0.
///
/// # Errors
/// Propagates every error returned by `launch_reduce_bool`.
pub fn gpu_any_bool(a: &CudaSlice<u8>, d: &GpuDevice) -> GpuResult<CudaSlice<u8>> {
launch_reduce_bool(a, d, REDUCE_ANY, 0)
}
/// Returns a one-element device buffer holding 1 when every byte of `a` is
/// non-zero, otherwise 0. An empty `a` yields 1.
///
/// # Errors
/// Propagates every error returned by `launch_reduce_bool`.
pub fn gpu_all_bool(a: &CudaSlice<u8>, d: &GpuDevice) -> GpuResult<CudaSlice<u8>> {
launch_reduce_bool(a, d, REDUCE_ALL, 1)
}
/// Maximum tensor rank accepted by the boolean broadcast path; the broadcast
/// kernel takes exactly this many output-stride and source-stride parameters.
pub const BOOL_BROADCAST_MAX_DIMS: usize = 8;
/// PTX source of `bool_broadcast_kernel(input_ptr, output_ptr, n, os0..os7, ss0..ss7)`.
///
/// Each thread with global index `r_tid < n` decomposes its flat output index
/// over eight dimensions. For every dimension it divides the remaining flat
/// index by the output stride `os*` to get a coordinate, subtracts
/// `coord * os*` from the remainder, and adds `coord * ss*` to the source
/// index. It then copies the byte `input[src_idx]` to `output[r_tid]`.
///
/// All index arithmetic is 32-bit, and every `os*` is used as a divisor, so
/// none of them may be zero.
const BOOL_BROADCAST_PTX: &str = "\
.version 7.0
.target sm_52
.address_size 64
.visible .entry bool_broadcast_kernel(
.param .u64 input_ptr,
.param .u64 output_ptr,
.param .u32 n,
.param .u32 os0, .param .u32 os1, .param .u32 os2, .param .u32 os3,
.param .u32 os4, .param .u32 os5, .param .u32 os6, .param .u32 os7,
.param .u32 ss0, .param .u32 ss1, .param .u32 ss2, .param .u32 ss3,
.param .u32 ss4, .param .u32 ss5, .param .u32 ss6, .param .u32 ss7
) {
.reg .u32 %r_tid, %bid, %bdim, %n_reg;
.reg .u32 %flat, %src_idx, %coord, %tmp, %os, %ss;
.reg .u64 %in, %out, %off;
.reg .u16 %val;
.reg .pred %p;
ld.param.u64 %in, [input_ptr];
ld.param.u64 %out, [output_ptr];
ld.param.u32 %n_reg, [n];
mov.u32 %bid, %ctaid.x;
mov.u32 %bdim, %ntid.x;
mov.u32 %r_tid, %tid.x;
mad.lo.u32 %r_tid, %bid, %bdim, %r_tid;
setp.ge.u32 %p, %r_tid, %n_reg;
@%p bra DONE;
mov.u32 %flat, %r_tid;
mov.u32 %src_idx, 0;
// Dim 0
ld.param.u32 %os, [os0];
ld.param.u32 %ss, [ss0];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 1
ld.param.u32 %os, [os1];
ld.param.u32 %ss, [ss1];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 2
ld.param.u32 %os, [os2];
ld.param.u32 %ss, [ss2];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 3
ld.param.u32 %os, [os3];
ld.param.u32 %ss, [ss3];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 4
ld.param.u32 %os, [os4];
ld.param.u32 %ss, [ss4];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 5
ld.param.u32 %os, [os5];
ld.param.u32 %ss, [ss5];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 6
ld.param.u32 %os, [os6];
ld.param.u32 %ss, [ss6];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Dim 7
ld.param.u32 %os, [os7];
ld.param.u32 %ss, [ss7];
div.u32 %coord, %flat, %os;
mul.lo.u32 %tmp, %coord, %os;
sub.u32 %flat, %flat, %tmp;
mul.lo.u32 %tmp, %coord, %ss;
add.u32 %src_idx, %src_idx, %tmp;
// Load in[src_idx] (1 byte per element: byte offset == element index).
cvt.u64.u32 %off, %src_idx;
add.u64 %off, %in, %off;
ld.global.u8 %val, [%off];
// Store out[r_tid].
cvt.u64.u32 %off, %r_tid;
add.u64 %off, %out, %off;
st.global.u8 [%off], %val;
DONE:
ret;
}
";
/// Prepares the fixed-width stride arrays passed to the boolean broadcast kernel.
///
/// Returns `(out_strides, src_strides)`, both of length `BOOL_BROADCAST_MAX_DIMS`:
/// * `out_strides[..rank]` are the row-major strides of `out_shape`; the
///   remaining slots hold a non-zero filler derived from `n`
///   (`(n as u32) + 1`, saturating), so the kernel never divides by zero there.
/// * `src_strides[..rank]` are the caller's `src_strides` narrowed to `u32`;
///   the remaining slots are 0.
///
/// # Errors
/// `GpuError::ShapeMismatch` when the two slices differ in length, when the
/// rank exceeds `BOOL_BROADCAST_MAX_DIMS`, or when an output or source stride
/// does not fit in `u32`.
fn pad_bool_broadcast_params(
    out_shape: &[usize],
    src_strides: &[usize],
    n: usize,
) -> GpuResult<(
    [u32; BOOL_BROADCAST_MAX_DIMS],
    [u32; BOOL_BROADCAST_MAX_DIMS],
)> {
    let rank = out_shape.len();
    if rank != src_strides.len() {
        return Err(GpuError::ShapeMismatch {
            op: "bool_broadcast_pad",
            expected: vec![rank],
            got: vec![src_strides.len()],
        });
    }
    if rank > BOOL_BROADCAST_MAX_DIMS {
        return Err(GpuError::ShapeMismatch {
            op: "bool_broadcast_pad",
            expected: vec![BOOL_BROADCAST_MAX_DIMS],
            got: vec![rank],
        });
    }

    // Narrows a stride to `u32`, reporting the offending value under `op`.
    let narrow = |value: usize, op: &'static str| -> GpuResult<u32> {
        if value > u32::MAX as usize {
            Err(GpuError::ShapeMismatch {
                op,
                expected: vec![u32::MAX as usize],
                got: vec![value],
            })
        } else {
            Ok(value as u32)
        }
    };

    // Start every slot at the filler; the leading `rank` slots are overwritten below.
    let filler = (n as u32).saturating_add(1).max(1);
    let mut dst_strides = [filler; BOOL_BROADCAST_MAX_DIMS];

    // Walk the dimensions innermost-first, accumulating the row-major stride.
    let mut running: usize = 1;
    for (slot, &extent) in dst_strides[..rank].iter_mut().zip(out_shape).rev() {
        *slot = narrow(running, "bool_broadcast_stride_overflow")?;
        running = running.saturating_mul(extent);
    }

    let mut gather_strides = [0u32; BOOL_BROADCAST_MAX_DIMS];
    for (slot, &stride) in gather_strides.iter_mut().zip(src_strides) {
        *slot = narrow(stride, "bool_broadcast_src_stride_overflow")?;
    }

    Ok((dst_strides, gather_strides))
}
/// Materialises a strided/broadcast view of a byte buffer into a dense buffer.
///
/// The output has `out_shape.iter().product()` elements (1 for an empty
/// shape). Output element with row-major coordinates `c` is copied from
/// `input[sum(c[d] * src_strides[d])]`, so a stride of 0 repeats the source
/// along that dimension. A shape containing a zero extent yields an empty buffer.
///
/// # Errors
/// * `GpuError::ShapeMismatch` when the element count overflows `usize` or
///   exceeds `u32::MAX`, when the largest source offset does not fit in `u32`
///   (the kernel indexes with 32-bit registers), or for any rank/stride
///   problem reported by `pad_bool_broadcast_params`.
/// * `GpuError::LengthMismatch` when `input` is too short for the largest
///   source offset implied by `out_shape` and `src_strides`; launching anyway
///   would read past the end of the device buffer.
/// * `GpuError::PtxCompileFailed` when `get_or_compile` fails.
/// * Any error propagated by `?` from allocation or launch.
#[allow(clippy::too_many_lines, reason = "8-dim unrolled launch arg list")]
pub fn gpu_broadcast_bool(
    input: &CudaSlice<u8>,
    out_shape: &[usize],
    src_strides: &[usize],
    device: &GpuDevice,
) -> GpuResult<CudaSlice<u8>> {
    // A zero extent makes the result empty regardless of the other extents,
    // so it is handled before the overflow-checked product.
    let n: usize = if out_shape.contains(&0) {
        0
    } else {
        out_shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| GpuError::ShapeMismatch {
                op: "bool_broadcast_numel_overflow",
                expected: vec![u32::MAX as usize],
                got: out_shape.to_vec(),
            })?
    };

    let (out_stride, src_stride) = pad_bool_broadcast_params(out_shape, src_strides, n)?;
    let stream = device.stream();
    if n == 0 {
        return Ok(stream.alloc_zeros::<u8>(0)?);
    }

    // The kernel takes `n` as `.u32`; a bare cast would truncate and leave
    // most of the output unwritten.
    if n > u32::MAX as usize {
        return Err(GpuError::ShapeMismatch {
            op: "bool_broadcast_numel_overflow",
            expected: vec![u32::MAX as usize],
            got: vec![n],
        });
    }
    let n_u32 = n as u32;

    // Largest source offset any output element maps to. Every extent is >= 1
    // here (n != 0) and both slices have equal length (checked by the pad helper).
    let max_src = out_shape
        .iter()
        .zip(src_strides)
        .try_fold(0usize, |acc, (&dim, &stride)| {
            (dim - 1)
                .checked_mul(stride)
                .and_then(|span| acc.checked_add(span))
        })
        .filter(|&offset| offset <= u32::MAX as usize)
        .ok_or_else(|| GpuError::ShapeMismatch {
            op: "bool_broadcast_src_index_overflow",
            expected: vec![u32::MAX as usize],
            got: src_strides.to_vec(),
        })?;
    if max_src >= input.len() {
        return Err(GpuError::LengthMismatch {
            a: input.len(),
            b: max_src.saturating_add(1),
        });
    }

    let ctx = device.context();
    let f = get_or_compile(
        ctx,
        BOOL_BROADCAST_PTX,
        "bool_broadcast_kernel",
        device.ordinal() as u32,
    )
    .map_err(|e| GpuError::PtxCompileFailed {
        kernel: "bool_broadcast_kernel",
        source: e,
    })?;

    let mut out = stream.alloc_zeros::<u8>(n)?;
    let cfg = launch_1d(n);
    // SAFETY: the argument list matches the kernel's
    // `(input_ptr, output_ptr, n, os0..os7, ss0..ss7)` parameters. `out` holds
    // `n` bytes, every source offset is at most `max_src < input.len()`, and
    // every output stride is non-zero (n != 0, padded slots use a non-zero filler).
    unsafe {
        stream
            .launch_builder(&f)
            .arg(input)
            .arg(&mut out)
            .arg(&n_u32)
            .arg(&out_stride[0])
            .arg(&out_stride[1])
            .arg(&out_stride[2])
            .arg(&out_stride[3])
            .arg(&out_stride[4])
            .arg(&out_stride[5])
            .arg(&out_stride[6])
            .arg(&out_stride[7])
            .arg(&src_stride[0])
            .arg(&src_stride[1])
            .arg(&src_stride[2])
            .arg(&src_stride[3])
            .arg(&src_stride[4])
            .arg(&src_stride[5])
            .arg(&src_stride[6])
            .arg(&src_stride[7])
            .launch(cfg)?;
    }
    Ok(out)
}