use crate::dtype::Element;
use crate::ops::ScatterReduceOp;
/// Gathers elements from `a` along dimension `dim` using `indices`.
///
/// For every position `p` in `index_shape`, reads the element of `a` whose
/// coordinates equal `p` except at axis `dim`, where `indices[p]` is
/// substituted, and writes it to `out[p]`. Negative or out-of-range indices
/// produce `T::zero()` instead of reading out of bounds.
///
/// # Safety
/// - `a` must be valid for reads of `shape.iter().product()` elements.
/// - `indices` and `out` must each be valid for
///   `index_shape.iter().product()` reads / writes respectively.
/// - Assumes `dim < shape.len()` and `shape.len() == index_shape.len()`
///   (not checked here — TODO confirm callers validate).
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn gather_kernel<T: Element>(
    a: *const T,
    indices: *const i64,
    out: *mut T,
    shape: &[usize],
    index_shape: &[usize],
    dim: usize,
) {
    let ndim = shape.len();
    if ndim == 0 {
        return;
    }
    // Row-major (C-contiguous) strides for the source and index tensors.
    let mut a_strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        a_strides[i] = a_strides[i + 1] * shape[i + 1];
    }
    let mut idx_strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        idx_strides[i] = idx_strides[i + 1] * index_shape[i + 1];
    }
    let total = index_shape.iter().product::<usize>();
    // Scratch for the decoded multi-index: allocated once and fully
    // overwritten each iteration (previously reallocated per element).
    let mut multi_idx = vec![0usize; ndim];
    for out_idx in 0..total {
        let mut remaining = out_idx;
        for d in 0..ndim {
            multi_idx[d] = remaining / idx_strides[d];
            remaining %= idx_strides[d];
        }
        let index_val = *indices.add(out_idx);
        // Invalid index: emit zero rather than reading out of bounds.
        if index_val < 0 || index_val as usize >= shape[dim] {
            *out.add(out_idx) = T::zero();
            continue;
        }
        let mut src_offset = 0;
        for d in 0..ndim {
            let coord = if d == dim {
                index_val as usize
            } else {
                multi_idx[d]
            };
            src_offset += coord * a_strides[d];
        }
        *out.add(out_idx) = *a.add(src_offset);
    }
}
/// Scatters `src` into a copy of `a` along dimension `dim`.
///
/// `out` is first initialized from `a`; then for every position `p` in
/// `index_shape`, `src[p]` is written to the element of `out` whose
/// coordinates equal `p` except at axis `dim`, where `indices[p]` is
/// substituted. Entries with negative or out-of-range indices are skipped.
/// When several index entries collide, the last write in row-major order
/// of `src` wins.
///
/// # Safety
/// - `a` must be valid for `shape.iter().product()` reads and `out` for the
///   same number of writes; the two regions must not overlap.
/// - `indices` and `src` must each be valid for
///   `index_shape.iter().product()` reads.
/// - Assumes `dim < shape.len()` and `shape.len() == index_shape.len()`.
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn scatter_kernel<T: Element>(
    a: *const T,
    indices: *const i64,
    src: *const T,
    out: *mut T,
    shape: &[usize],
    index_shape: &[usize],
    dim: usize,
) {
    let ndim = shape.len();
    if ndim == 0 {
        return;
    }
    let a_numel: usize = shape.iter().product();
    std::ptr::copy_nonoverlapping(a, out, a_numel);
    // Row-major strides for the output and index tensors.
    let mut out_strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        out_strides[i] = out_strides[i + 1] * shape[i + 1];
    }
    let mut idx_strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        idx_strides[i] = idx_strides[i + 1] * index_shape[i + 1];
    }
    let total = index_shape.iter().product::<usize>();
    // Scratch for the decoded multi-index: allocated once and fully
    // overwritten each iteration (previously reallocated per element).
    let mut multi_idx = vec![0usize; ndim];
    for src_idx in 0..total {
        let mut remaining = src_idx;
        for d in 0..ndim {
            multi_idx[d] = remaining / idx_strides[d];
            remaining %= idx_strides[d];
        }
        let index_val = *indices.add(src_idx);
        // Invalid index: silently skip this source element.
        if index_val < 0 || index_val as usize >= shape[dim] {
            continue;
        }
        let mut dst_offset = 0;
        for d in 0..ndim {
            let coord = if d == dim {
                index_val as usize
            } else {
                multi_idx[d]
            };
            dst_offset += coord * out_strides[d];
        }
        *out.add(dst_offset) = *src.add(src_idx);
    }
}
/// Selects `index_len` slices from `a` along dimension `dim`, writing them to
/// `out` in the order given by `indices`. Slices addressed by out-of-range
/// indices (including negative values, which wrap to huge `usize`s on cast)
/// are zero-filled in `out`.
///
/// # Safety
/// - `a` must be valid for `shape.iter().product()` reads.
/// - `indices` must be valid for `index_len` reads.
/// - `out` must be valid for `outer_size * index_len * inner_size` writes.
/// - Assumes `dim < shape.len()`.
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn index_select_kernel<T: Element>(
    a: *const T,
    indices: *const i64,
    out: *mut T,
    shape: &[usize],
    dim: usize,
    index_len: usize,
) {
    let ndim = shape.len();
    if ndim == 0 {
        return;
    }
    let outer_size: usize = shape[..dim].iter().product();
    let dim_size = shape[dim];
    let inner_size: usize = shape[dim + 1..].iter().product();
    // Empty slice products are already 1, so these are 0 only when some
    // dimension is 0 — i.e. both tensors are empty. Bail out instead of
    // clamping to 1, which would write past a zero-length output buffer.
    if outer_size == 0 || inner_size == 0 {
        return;
    }
    let idx_slice = std::slice::from_raw_parts(indices, index_len);
    for outer in 0..outer_size {
        for (sel_idx, &raw_idx) in idx_slice.iter().enumerate() {
            let out_base = (outer * index_len + sel_idx) * inner_size;
            // A negative i64 wraps to a huge usize and fails the bound check.
            let idx = raw_idx as usize;
            if idx >= dim_size {
                for inner in 0..inner_size {
                    *out.add(out_base + inner) = T::zero();
                }
                continue;
            }
            let src_base = (outer * dim_size + idx) * inner_size;
            for inner in 0..inner_size {
                *out.add(out_base + inner) = *a.add(src_base + inner);
            }
        }
    }
}
/// Writes slices of `src` into a copy of `a` along dimension `dim` at the
/// positions given by `indices`. `out` is first initialized from `a`;
/// out-of-range indices (including negative values, which wrap on cast)
/// are skipped. On duplicate indices, the last `src` slice wins.
///
/// # Safety
/// - `a` must be valid for `shape.iter().product()` reads and `out` for the
///   same number of writes; the two regions must not overlap.
/// - `indices` must be valid for `index_len` reads.
/// - `src` must be valid for `outer_size * index_len * inner_size` reads.
/// - Assumes `dim < shape.len()`.
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn index_put_kernel<T: Element>(
    a: *const T,
    indices: *const i64,
    src: *const T,
    out: *mut T,
    shape: &[usize],
    dim: usize,
    index_len: usize,
) {
    let ndim = shape.len();
    if ndim == 0 {
        return;
    }
    let outer_size: usize = shape[..dim].iter().product();
    let dim_size = shape[dim];
    let inner_size: usize = shape[dim + 1..].iter().product();
    let total_size: usize = shape.iter().product();
    std::ptr::copy_nonoverlapping(a, out, total_size);
    // Empty slice products are already 1, so these are 0 only when some
    // dimension is 0 — the tensors are empty. Bail out instead of clamping
    // to 1, which would read/write past zero-length buffers.
    if outer_size == 0 || inner_size == 0 {
        return;
    }
    let idx_slice = std::slice::from_raw_parts(indices, index_len);
    for outer in 0..outer_size {
        for (sel_idx, &raw_idx) in idx_slice.iter().enumerate() {
            // A negative i64 wraps to a huge usize and fails the bound check.
            let idx = raw_idx as usize;
            if idx >= dim_size {
                continue;
            }
            let dst_base = (outer * dim_size + idx) * inner_size;
            let src_base = (outer * index_len + sel_idx) * inner_size;
            for inner in 0..inner_size {
                *out.add(dst_base + inner) = *src.add(src_base + inner);
            }
        }
    }
}
/// Counts the nonzero bytes in `mask`.
///
/// On x86_64/aarch64 this defers to the SIMD helper; other targets use a
/// scalar scan.
///
/// # Safety
/// `mask` must be valid for reads of `numel` bytes.
#[inline]
#[allow(dead_code)]
pub unsafe fn masked_count_kernel(mask: *const u8, numel: usize) -> usize {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::simd::index;
        return index::masked_count(mask, numel);
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        std::slice::from_raw_parts(mask, numel)
            .iter()
            .filter(|&&byte| byte != 0)
            .count()
    }
}
/// Compacts the elements of `a` whose corresponding `mask` byte is nonzero
/// into the front of `out`, preserving order.
///
/// On x86_64/aarch64, f32/f64 element types are routed to SIMD helpers; all
/// other types take the scalar fallback.
///
/// # Safety
/// - `a` and `mask` must each be valid for `numel` reads.
/// - `out` must be valid for at least as many writes as there are nonzero
///   mask bytes.
#[inline]
pub unsafe fn masked_select_kernel<T: Element>(
    a: *const T,
    mask: *const u8,
    out: *mut T,
    numel: usize,
) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::simd::index;
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
            let _ = index::masked_select_f32(a as *const f32, mask, out as *mut f32, numel);
            return;
        } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
            let _ = index::masked_select_f64(a as *const f64, mask, out as *mut f64, numel);
            return;
        }
    }
    // Scalar fallback: stable compaction of the selected elements.
    let values = std::slice::from_raw_parts(a, numel);
    let keep = std::slice::from_raw_parts(mask, numel);
    let mut written = 0usize;
    for (&value, &flag) in values.iter().zip(keep.iter()) {
        if flag != 0 {
            *out.add(written) = value;
            written += 1;
        }
    }
}
/// Copies `a` into `out`, replacing every element whose `mask` byte is
/// nonzero with `value` (converted to `T`).
///
/// On x86_64/aarch64, f32/f64 element types are routed to SIMD helpers; all
/// other types take the scalar fallback.
///
/// # Safety
/// - `a` and `mask` must each be valid for `numel` reads.
/// - `out` must be valid for `numel` writes.
#[inline]
pub unsafe fn masked_fill_kernel<T: Element>(
    a: *const T,
    mask: *const u8,
    out: *mut T,
    numel: usize,
    value: f64,
) {
    #[cfg(any(target_arch = "x86_64", target_arch = "aarch64"))]
    {
        use super::simd::index;
        if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
            index::masked_fill_f32(a as *const f32, mask, out as *mut f32, numel, value as f32);
            return;
        } else if std::any::TypeId::of::<T>() == std::any::TypeId::of::<f64>() {
            index::masked_fill_f64(a as *const f64, mask, out as *mut f64, numel, value);
            return;
        }
    }
    // Scalar fallback: element-wise select between the fill value and input.
    let values = std::slice::from_raw_parts(a, numel);
    let flags = std::slice::from_raw_parts(mask, numel);
    let dst = std::slice::from_raw_parts_mut(out, numel);
    let fill_val = T::from_f64(value);
    for ((slot, &value), &flag) in dst.iter_mut().zip(values.iter()).zip(flags.iter()) {
        *slot = if flag != 0 { fill_val } else { value };
    }
}
/// Copies one `embedding_dim`-long row of `embeddings` into `out` for each
/// entry of `indices`. Rows addressed by negative or out-of-range indices
/// are zero-filled.
///
/// # Safety
/// - `embeddings` must be valid for `vocab_size * embedding_dim` reads.
/// - `indices` must be valid for `num_indices` reads.
/// - `out` must be valid for `num_indices * embedding_dim` writes, and must
///   not overlap `embeddings`.
#[inline]
pub unsafe fn embedding_lookup_kernel<T: Element>(
    embeddings: *const T,
    indices: *const i64,
    out: *mut T,
    num_indices: usize,
    vocab_size: usize,
    embedding_dim: usize,
) {
    if num_indices == 0 || embedding_dim == 0 {
        return;
    }
    for i in 0..num_indices {
        let idx = *indices.add(i);
        let row_out = out.add(i * embedding_dim);
        if idx >= 0 && (idx as usize) < vocab_size {
            // In-range index: bulk-copy the embedding row.
            let row_src = embeddings.add(idx as usize * embedding_dim);
            std::ptr::copy_nonoverlapping(row_src, row_out, embedding_dim);
        } else {
            // Invalid index: emit a zero row.
            for j in 0..embedding_dim {
                *row_out.add(j) = T::zero();
            }
        }
    }
}
/// Scatter-reduces `src` into `out` along dimension `dim`.
///
/// `out` is seeded either from `dst` (`include_self == true`) or with the
/// reduction identity for `op`. Then for every position `p` in `index_shape`,
/// `src[p]` is combined into the element of `out` whose coordinates equal `p`
/// except at axis `dim`, where `indices[p]` is substituted. For `Mean`, a
/// non-null `counts` buffer tracks how many values were accumulated into each
/// output element (starting at 1 per element when `include_self`), and the
/// accumulated sums are divided by the counts at the end. Negative or
/// out-of-range indices are skipped.
///
/// # Safety
/// - `dst` must be valid for `shape.iter().product()` reads and `out` for the
///   same number of writes; the regions must not overlap.
/// - `indices` and `src` must each be valid for
///   `index_shape.iter().product()` reads.
/// - `counts`, if non-null, must be valid for `shape.iter().product()`
///   `u32` writes.
/// - Assumes `dim < shape.len()` and `shape.len() == index_shape.len()`.
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn scatter_reduce_kernel<T: Element>(
    dst: *const T,
    indices: *const i64,
    src: *const T,
    out: *mut T,
    counts: *mut u32,
    shape: &[usize],
    index_shape: &[usize],
    dim: usize,
    op: ScatterReduceOp,
    include_self: bool,
) {
    let ndim = shape.len();
    if ndim == 0 {
        return;
    }
    let dst_numel: usize = shape.iter().product();
    let use_counts = op == ScatterReduceOp::Mean && !counts.is_null();
    if include_self {
        // Seed with the destination tensor; each element already counts once
        // toward the mean.
        std::ptr::copy_nonoverlapping(dst, out, dst_numel);
        if use_counts {
            std::slice::from_raw_parts_mut(counts, dst_numel).fill(1);
        }
    } else {
        // Seed with the reduction identity for `op`. Note Max/Min use +/-inf
        // converted through from_f64, matching the original initialization.
        let identity = match op {
            ScatterReduceOp::Sum | ScatterReduceOp::Mean => T::zero(),
            ScatterReduceOp::Prod => T::one(),
            ScatterReduceOp::Max => T::from_f64(f64::NEG_INFINITY),
            ScatterReduceOp::Min => T::from_f64(f64::INFINITY),
        };
        let out_slice = std::slice::from_raw_parts_mut(out, dst_numel);
        for elem in out_slice.iter_mut() {
            *elem = identity;
        }
        if use_counts {
            std::slice::from_raw_parts_mut(counts, dst_numel).fill(0);
        }
    }
    // Row-major strides for the output and index tensors.
    let mut out_strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        out_strides[i] = out_strides[i + 1] * shape[i + 1];
    }
    let mut idx_strides = vec![1usize; ndim];
    for i in (0..ndim - 1).rev() {
        idx_strides[i] = idx_strides[i + 1] * index_shape[i + 1];
    }
    let total = index_shape.iter().product::<usize>();
    // Scratch for the decoded multi-index: allocated once and fully
    // overwritten each iteration (previously reallocated per element).
    let mut multi_idx = vec![0usize; ndim];
    for src_idx in 0..total {
        let mut remaining = src_idx;
        for d in 0..ndim {
            multi_idx[d] = remaining / idx_strides[d];
            remaining %= idx_strides[d];
        }
        let index_val = *indices.add(src_idx);
        if index_val < 0 || index_val as usize >= shape[dim] {
            continue;
        }
        let mut dst_offset = 0;
        for d in 0..ndim {
            let coord = if d == dim {
                index_val as usize
            } else {
                multi_idx[d]
            };
            dst_offset += coord * out_strides[d];
        }
        let src_val = *src.add(src_idx);
        let dst_val = *out.add(dst_offset);
        // Comparisons go through f64 so the same code handles every Element.
        let new_val = match op {
            ScatterReduceOp::Sum | ScatterReduceOp::Mean => dst_val + src_val,
            ScatterReduceOp::Prod => dst_val * src_val,
            ScatterReduceOp::Max => {
                if src_val.to_f64() > dst_val.to_f64() {
                    src_val
                } else {
                    dst_val
                }
            }
            ScatterReduceOp::Min => {
                if src_val.to_f64() < dst_val.to_f64() {
                    src_val
                } else {
                    dst_val
                }
            }
        };
        *out.add(dst_offset) = new_val;
        if use_counts {
            *counts.add(dst_offset) += 1;
        }
    }
    // Finalize the mean: divide each accumulated sum by its element count.
    // Elements with count 0 (possible only when !include_self and never
    // scattered into) keep their identity value.
    if use_counts {
        let out_slice = std::slice::from_raw_parts_mut(out, dst_numel);
        let counts_slice = std::slice::from_raw_parts(counts, dst_numel);
        for (elem, &count) in out_slice.iter_mut().zip(counts_slice.iter()) {
            if count > 0 {
                *elem = T::from_f64(elem.to_f64() / count as f64);
            }
        }
    }
}
/// N-dimensional gather: the last axis of `indices` holds coordinate tuples
/// of length `index_depth` addressing (prefixes of) positions in `input`;
/// each tuple selects a trailing sub-block of `input` that is copied to the
/// corresponding position of `out`. Tuples containing a negative or
/// out-of-range coordinate produce a zero-filled block.
///
/// # Safety
/// - `input` must be valid for `input_shape.iter().product()` reads.
/// - `indices` must be valid for `indices_shape.iter().product()` reads.
/// - `out` must be valid for `num_indices * trailing_size` writes.
/// - Assumes `indices_shape.last() <= input_shape.len()` — a deeper index
///   tuple would panic on the `input_shape[d]` bound check below.
#[inline]
#[allow(clippy::too_many_arguments)]
pub unsafe fn gather_nd_kernel<T: Element>(
    input: *const T,
    indices: *const i64,
    out: *mut T,
    input_shape: &[usize],
    indices_shape: &[usize],
    out_shape: &[usize],
) {
    // `out_shape` is kept for interface compatibility; the output layout is
    // fully determined by `indices_shape` and `input_shape` (the previous
    // out-stride computation was dead code and has been removed).
    let _ = out_shape;
    if input_shape.is_empty() || indices_shape.is_empty() {
        return;
    }
    let input_ndim = input_shape.len();
    let indices_ndim = indices_shape.len();
    let index_depth = indices_shape[indices_ndim - 1];
    // Row-major strides for the input tensor.
    let mut input_strides = vec![1usize; input_ndim];
    for i in (0..input_ndim - 1).rev() {
        input_strides[i] = input_strides[i + 1] * input_shape[i + 1];
    }
    // Number of coordinate tuples (product of all but the last indices dim).
    let num_indices: usize = indices_shape[..indices_ndim - 1]
        .iter()
        .product::<usize>()
        .max(1);
    // Size of the sub-block each tuple selects (1 for a full-depth index).
    let trailing_size: usize = if index_depth < input_ndim {
        input_shape[index_depth..].iter().product()
    } else {
        1
    };
    for idx_vec in 0..num_indices {
        let indices_offset = idx_vec * index_depth;
        let mut input_offset = 0usize;
        let mut valid = true;
        for d in 0..index_depth {
            let coord = *indices.add(indices_offset + d);
            if coord < 0 || coord as usize >= input_shape[d] {
                valid = false;
                break;
            }
            input_offset += (coord as usize) * input_strides[d];
        }
        let out_offset = idx_vec * trailing_size;
        if !valid {
            // Invalid tuple: zero-fill the whole destination block.
            for i in 0..trailing_size {
                *out.add(out_offset + i) = T::zero();
            }
        } else {
            for i in 0..trailing_size {
                *out.add(out_offset + i) = *input.add(input_offset + i);
            }
        }
    }
}
/// Histograms the values of `input` into `out`: each bin receives either a
/// count of occurrences, or the sum of the matching `weights` entries when
/// `weights` is non-null. Values `>= output_len` are ignored.
///
/// Returns `false` as soon as a negative input value is encountered (with
/// `out` zeroed and possibly partially accumulated); `true` otherwise.
///
/// # Safety
/// - `input` must be valid for `numel` reads.
/// - `weights`, if non-null, must be valid for `numel` reads.
/// - `out` must be valid for `output_len` writes.
#[inline]
pub unsafe fn bincount_kernel<T: Element>(
    input: *const i64,
    weights: *const T,
    out: *mut T,
    numel: usize,
    output_len: usize,
) -> bool {
    // Start every bin at zero.
    let bins = std::slice::from_raw_parts_mut(out, output_len);
    for bin in bins.iter_mut() {
        *bin = T::zero();
    }
    let values = std::slice::from_raw_parts(input, numel);
    let weighted = !weights.is_null();
    for (i, &val) in values.iter().enumerate() {
        if val < 0 {
            return false;
        }
        let slot = val as usize;
        if slot < output_len {
            let delta = if weighted { *weights.add(i) } else { T::one() };
            bins[slot] = bins[slot] + delta;
        }
    }
    true
}
/// Returns the maximum of `numel` i64 values starting at `input`, or -1 when
/// `numel` is 0 (in which case `input` is never dereferenced).
///
/// # Safety
/// `input` must be valid for `numel` reads when `numel > 0`.
#[inline]
pub unsafe fn max_i64_kernel(input: *const i64, numel: usize) -> i64 {
    if numel == 0 {
        return -1;
    }
    let values = std::slice::from_raw_parts(input, numel);
    values.iter().copied().max().unwrap_or(-1)
}
/// Gathers `num_indices` elements from a row-major `nrows x ncols` matrix:
/// `out[i] = input[rows[i], cols[i]]`.
///
/// Returns `false` (with `out` possibly partially written) on the first
/// negative or out-of-range row/column index; `true` on success.
///
/// Bound relaxed from `Element` to `Copy`: the kernel only copies elements,
/// so this is a backward-compatible generalization consistent with
/// `slice_assign_kernel`.
///
/// # Safety
/// - `input` must be valid for `nrows * ncols` reads.
/// - `rows` and `cols` must each be valid for `num_indices` reads.
/// - `out` must be valid for `num_indices` writes.
#[inline]
pub unsafe fn gather_2d_kernel<T: Copy>(
    input: *const T,
    rows: *const i64,
    cols: *const i64,
    out: *mut T,
    nrows: usize,
    ncols: usize,
    num_indices: usize,
) -> bool {
    if num_indices == 0 {
        return true;
    }
    let rows_slice = std::slice::from_raw_parts(rows, num_indices);
    let cols_slice = std::slice::from_raw_parts(cols, num_indices);
    for i in 0..num_indices {
        let r = rows_slice[i];
        let c = cols_slice[i];
        // Unlike the gather/scatter kernels above, invalid indices here are
        // reported to the caller rather than zero-filled or skipped.
        if r < 0 || r as usize >= nrows || c < 0 || c as usize >= ncols {
            return false;
        }
        let input_offset = (r as usize) * ncols + (c as usize);
        *out.add(i) = *input.add(input_offset);
    }
    true
}
/// Copies `dst` into `out`, then overwrites rows `[start, start + src_dim_size)`
/// of the sliced dimension in every outer block with the rows of `src`.
///
/// Layouts are row-major: `dst`/`out` are `outer_size x dst_dim_size x
/// inner_size`, `src` is `outer_size x src_dim_size x inner_size`.
///
/// # Safety
/// - `dst` must be valid for `outer_size * dst_dim_size * inner_size` reads
///   and `out` for the same number of writes; the regions must not overlap.
/// - `src` must be valid for `outer_size * src_dim_size * inner_size` reads
///   and must not overlap `out`.
/// - Assumes `start + src_dim_size <= dst_dim_size`.
pub unsafe fn slice_assign_kernel<T: Copy>(
    dst: *const T,
    src: *const T,
    out: *mut T,
    outer_size: usize,
    dst_dim_size: usize,
    src_dim_size: usize,
    inner_size: usize,
    start: usize,
) {
    // Start from a verbatim copy of the destination tensor.
    std::ptr::copy_nonoverlapping(dst, out, outer_size * dst_dim_size * inner_size);
    // Then overwrite the assigned window of each outer block row by row.
    for outer in 0..outer_size {
        let src_block = outer * src_dim_size * inner_size;
        let out_block = outer * dst_dim_size * inner_size;
        for row in 0..src_dim_size {
            std::ptr::copy_nonoverlapping(
                src.add(src_block + row * inner_size),
                out.add(out_block + (start + row) * inner_size),
                inner_size,
            );
        }
    }
}