use super::super::kernels;
#[cfg(target_arch = "x86_64")]
use super::super::kernels::simd::index as simd_index;
use super::super::{CpuClient, CpuRuntime};
use crate::dispatch_dtype;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::ScatterReduceOp;
use crate::runtime::ensure_contiguous;
use crate::tensor::Tensor;
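/// Converts an index tensor to `I64`, copying `I32` data element-wise and
/// rejecting any other dtype with a `DTypeMismatch` error.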
fn normalize_indices_to_i64(
client: &CpuClient,
indices: &Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
match indices.dtype() {
DType::I64 => Ok(indices.clone()),
DType::I32 => {
let idx_i32: Vec<i32> = indices.to_vec();
let idx_i64: Vec<i64> = idx_i32.into_iter().map(i64::from).collect();
Ok(Tensor::<CpuRuntime>::from_slice(
&idx_i64,
indices.shape(),
&client.device,
))
}
other => Err(Error::DTypeMismatch {
lhs: DType::I64,
rhs: other,
}),
}
}
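/// Gathers elements of a 2-D `input` at paired `(rows[i], cols[i])`
/// coordinates, producing a 1-D output with one element per index pair.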
pub fn gather_2d_impl(
client: &CpuClient,
input: &Tensor<CpuRuntime>,
rows: &Tensor<CpuRuntime>,
cols: &Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
let dtype = input.dtype();
let shape = input.shape();
if shape.len() != 2 {
return Err(Error::ShapeMismatch {
            // Placeholder shape standing in for "any 2-D input".
            expected: vec![0, 0],
            got: shape.to_vec(),
});
}
let nrows = shape[0];
let ncols = shape[1];
let rows_i64 = normalize_indices_to_i64(client, rows)?;
let cols_i64 = normalize_indices_to_i64(client, cols)?;
if rows_i64.ndim() != 1 {
return Err(Error::ShapeMismatch {
expected: vec![rows_i64.numel()],
got: rows_i64.shape().to_vec(),
});
}
if cols_i64.ndim() != 1 {
return Err(Error::ShapeMismatch {
expected: vec![cols_i64.numel()],
got: cols_i64.shape().to_vec(),
});
}
let num_indices = rows_i64.numel();
if cols_i64.numel() != num_indices {
return Err(Error::ShapeMismatch {
expected: vec![num_indices],
got: cols_i64.shape().to_vec(),
});
}
let input_contig = ensure_contiguous(input);
let rows_contig = ensure_contiguous(&rows_i64);
let cols_contig = ensure_contiguous(&cols_i64);
let out = Tensor::<CpuRuntime>::empty(&[num_indices], dtype, &client.device);
let input_ptr = input_contig.ptr();
let rows_ptr = rows_contig.ptr();
let cols_ptr = cols_contig.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
let success = unsafe {
kernels::gather_2d_kernel::<T>(
input_ptr as *const T,
rows_ptr as *const i64,
cols_ptr as *const i64,
out_ptr as *mut T,
nrows,
ncols,
num_indices,
)
};
if !success {
return Err(Error::IndexOutOfBounds {
                // The kernel only reports failure, not which index was out of range.
                index: 0,
                size: nrows.max(ncols),
});
}
}, "gather_2d");
Ok(out)
}
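/// Gathers values from `a` along `dim` according to `index`; the output has
/// the same shape (and rank) as `index`.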
pub fn gather_impl(
client: &CpuClient,
a: &Tensor<CpuRuntime>,
dim: usize,
index: &Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
let dtype = a.dtype();
let shape = a.shape();
let ndim = shape.len();
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
let index_i64 = normalize_indices_to_i64(client, index)?;
if index_i64.ndim() != ndim {
return Err(Error::ShapeMismatch {
expected: shape.to_vec(),
got: index_i64.shape().to_vec(),
});
}
let out_shape = index_i64.shape().to_vec();
let a_contig = ensure_contiguous(a);
let index_contig = ensure_contiguous(&index_i64);
let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &client.device);
let a_ptr = a_contig.ptr();
let index_ptr = index_contig.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
kernels::gather_kernel::<T>(
a_ptr as *const T,
index_ptr as *const i64,
out_ptr as *mut T,
shape,
&out_shape,
dim,
);
}
}, "gather");
Ok(out)
}
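/// Returns a copy of `a` with values from `src` written along `dim` at the
/// positions given by `index`; `index` and `src` must have identical shapes.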
pub fn scatter_impl(
client: &CpuClient,
a: &Tensor<CpuRuntime>,
dim: usize,
index: &Tensor<CpuRuntime>,
src: &Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
let dtype = a.dtype();
let shape = a.shape();
let ndim = shape.len();
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
let index_i64 = normalize_indices_to_i64(client, index)?;
if src.dtype() != dtype {
return Err(Error::DTypeMismatch {
lhs: dtype,
rhs: src.dtype(),
});
}
if index_i64.shape() != src.shape() {
return Err(Error::ShapeMismatch {
expected: src.shape().to_vec(),
got: index_i64.shape().to_vec(),
});
}
let a_contig = ensure_contiguous(a);
let index_contig = ensure_contiguous(&index_i64);
let src_contig = ensure_contiguous(src);
let out = Tensor::<CpuRuntime>::empty(shape, dtype, &client.device);
let a_ptr = a_contig.ptr();
let index_ptr = index_contig.ptr();
let src_ptr = src_contig.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
kernels::scatter_kernel::<T>(
a_ptr as *const T,
index_ptr as *const i64,
src_ptr as *const T,
out_ptr as *mut T,
shape,
index_i64.shape(),
dim,
);
}
}, "scatter");
Ok(out)
}
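/// Selects slices of `a` along `dim` using a 1-D `index` tensor, after
/// bounds-checking every index against the size of that dimension.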
pub fn index_select_impl(
client: &CpuClient,
a: &Tensor<CpuRuntime>,
dim: usize,
index: &Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
let dtype = a.dtype();
let shape = a.shape();
let ndim = shape.len();
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
let index_i64 = normalize_indices_to_i64(client, index)?;
if index_i64.ndim() != 1 {
return Err(Error::ShapeMismatch {
expected: vec![index_i64.numel()],
got: index_i64.shape().to_vec(),
});
}
let index_len = index_i64.shape()[0];
let mut out_shape = shape.to_vec();
out_shape[dim] = index_len;
let a_contig = ensure_contiguous(a);
let index_contig = ensure_contiguous(&index_i64);
let dim_size = shape[dim];
let index_data =
unsafe { std::slice::from_raw_parts(index_contig.ptr() as *const i64, index_len) };
for &idx in index_data.iter() {
if idx < 0 || idx as usize >= dim_size {
return Err(Error::IndexOutOfBounds {
index: if idx < 0 { 0 } else { idx as usize },
size: dim_size,
});
}
}
let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &client.device);
let a_ptr = a_contig.ptr();
let index_ptr = index_contig.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
kernels::index_select_kernel::<T>(
a_ptr as *const T,
index_ptr as *const i64,
out_ptr as *mut T,
shape,
dim,
index_len,
);
}
}, "index_select");
Ok(out)
}
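/// Returns a copy of `a` in which the slices along `dim` selected by a 1-D
/// `index` are overwritten with the corresponding slices of `src`.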
pub fn index_put_impl(
client: &CpuClient,
a: &Tensor<CpuRuntime>,
dim: usize,
index: &Tensor<CpuRuntime>,
src: &Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
let dtype = a.dtype();
let shape = a.shape();
let ndim = shape.len();
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
let index_i64 = normalize_indices_to_i64(client, index)?;
if index_i64.ndim() != 1 {
return Err(Error::ShapeMismatch {
expected: vec![index_i64.numel()],
got: index_i64.shape().to_vec(),
});
}
if src.dtype() != dtype {
return Err(Error::DTypeMismatch {
lhs: dtype,
rhs: src.dtype(),
});
}
let index_len = index_i64.shape()[0];
let mut expected_src_shape = shape.to_vec();
expected_src_shape[dim] = index_len;
if src.shape() != expected_src_shape {
return Err(Error::ShapeMismatch {
expected: expected_src_shape,
got: src.shape().to_vec(),
});
}
let a_contig = ensure_contiguous(a);
let index_contig = ensure_contiguous(&index_i64);
let src_contig = ensure_contiguous(src);
let dim_size = shape[dim];
let index_data =
unsafe { std::slice::from_raw_parts(index_contig.ptr() as *const i64, index_len) };
for &idx in index_data.iter() {
if idx < 0 || idx as usize >= dim_size {
return Err(Error::IndexOutOfBounds {
index: if idx < 0 { 0 } else { idx as usize },
size: dim_size,
});
}
}
let out = Tensor::<CpuRuntime>::empty(shape, dtype, &client.device);
let a_ptr = a_contig.ptr();
let index_ptr = index_contig.ptr();
let src_ptr = src_contig.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
kernels::index_put_kernel::<T>(
a_ptr as *const T,
index_ptr as *const i64,
src_ptr as *const T,
out_ptr as *mut T,
shape,
dim,
index_len,
);
}
}, "index_put");
Ok(out)
}
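/// Collects the elements of `a` where the `U8` mask (broadcast to `a`'s shape)
/// is non-zero into a 1-D output; `f32`/`f64` inputs take SIMD fast paths on
/// x86_64.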
pub fn masked_select_impl(
client: &CpuClient,
a: &Tensor<CpuRuntime>,
mask: &Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
let dtype = a.dtype();
if mask.dtype() != DType::U8 {
return Err(Error::DTypeMismatch {
lhs: DType::U8,
rhs: mask.dtype(),
});
}
let mask_broadcast = mask.broadcast_to(a.shape())?;
let a_contig = ensure_contiguous(a);
let mask_contig = ensure_contiguous(&mask_broadcast);
let numel = a.numel();
let a_ptr = a_contig.ptr();
let mask_ptr = mask_contig.ptr();
#[cfg(target_arch = "x86_64")]
{
let count = unsafe { simd_index::masked_count(mask_ptr as *const u8, numel) };
let out = Tensor::<CpuRuntime>::empty(&[count], dtype, &client.device);
let out_ptr = out.ptr();
match dtype {
DType::F32 => {
unsafe {
simd_index::masked_select_f32(
a_ptr as *const f32,
mask_ptr as *const u8,
out_ptr as *mut f32,
numel,
);
}
return Ok(out);
}
DType::F64 => {
unsafe {
simd_index::masked_select_f64(
a_ptr as *const f64,
mask_ptr as *const u8,
out_ptr as *mut f64,
numel,
);
}
return Ok(out);
}
_ => {
dispatch_dtype!(dtype, T => {
unsafe {
kernels::masked_select_kernel::<T>(
a_ptr as *const T,
mask_ptr as *const u8,
out_ptr as *mut T,
numel,
);
}
}, "masked_select");
return Ok(out);
}
}
}
#[cfg(not(target_arch = "x86_64"))]
{
let count = unsafe { kernels::index::masked_count_kernel(mask_ptr as *const u8, numel) };
let out = Tensor::<CpuRuntime>::empty(&[count], dtype, &client.device);
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
kernels::masked_select_kernel::<T>(
a_ptr as *const T,
mask_ptr as *const u8,
out_ptr as *mut T,
numel,
);
}
}, "masked_select");
Ok(out)
}
}
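/// Returns a copy of `a` with `value` written wherever the `U8` mask
/// (broadcast to `a`'s shape) is non-zero; `f32`/`f64` use SIMD on x86_64.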
pub fn masked_fill_impl(
client: &CpuClient,
a: &Tensor<CpuRuntime>,
mask: &Tensor<CpuRuntime>,
value: f64,
) -> Result<Tensor<CpuRuntime>> {
let dtype = a.dtype();
if mask.dtype() != DType::U8 {
return Err(Error::DTypeMismatch {
lhs: DType::U8,
rhs: mask.dtype(),
});
}
let mask_broadcast = mask.broadcast_to(a.shape())?;
let a_contig = ensure_contiguous(a);
let mask_contig = ensure_contiguous(&mask_broadcast);
let out = Tensor::<CpuRuntime>::empty(a.shape(), dtype, &client.device);
let numel = a.numel();
let a_ptr = a_contig.ptr();
let mask_ptr = mask_contig.ptr();
let out_ptr = out.ptr();
#[cfg(target_arch = "x86_64")]
{
match dtype {
DType::F32 => {
unsafe {
simd_index::masked_fill_f32(
a_ptr as *const f32,
mask_ptr as *const u8,
out_ptr as *mut f32,
numel,
value as f32,
);
}
return Ok(out);
}
DType::F64 => {
unsafe {
simd_index::masked_fill_f64(
a_ptr as *const f64,
mask_ptr as *const u8,
out_ptr as *mut f64,
numel,
value,
);
}
return Ok(out);
}
            _ => {}
        }
}
dispatch_dtype!(dtype, T => {
unsafe {
kernels::masked_fill_kernel::<T>(
a_ptr as *const T,
mask_ptr as *const u8,
out_ptr as *mut T,
numel,
value,
);
}
}, "masked_fill");
Ok(out)
}
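/// Looks up rows of a 2-D `[vocab_size, embedding_dim]` table for every entry
/// of `indices`; the output shape is the index shape with `embedding_dim`
/// appended.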
pub fn embedding_lookup_impl(
client: &CpuClient,
embeddings: &Tensor<CpuRuntime>,
indices: &Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
let dtype = embeddings.dtype();
let emb_shape = embeddings.shape();
if emb_shape.len() != 2 {
return Err(Error::ShapeMismatch {
            // Placeholder shape standing in for "any 2-D embedding table".
            expected: vec![0, 0],
            got: emb_shape.to_vec(),
});
}
let indices_i64 = normalize_indices_to_i64(client, indices)?;
let vocab_size = emb_shape[0];
let embedding_dim = emb_shape[1];
let num_indices = indices_i64.numel();
let mut out_shape = indices_i64.shape().to_vec();
out_shape.push(embedding_dim);
let emb_contig = ensure_contiguous(embeddings);
let idx_contig = ensure_contiguous(&indices_i64);
let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &client.device);
let emb_ptr = emb_contig.ptr();
let idx_ptr = idx_contig.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
kernels::embedding_lookup_kernel::<T>(
emb_ptr as *const T,
idx_ptr as *const i64,
out_ptr as *mut T,
num_indices,
vocab_size,
embedding_dim,
);
}
}, "embedding_lookup");
Ok(out)
}
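/// Scatters `src` into a copy of `dst` along `dim`, combining colliding writes
/// with the reduction `op`; a per-element count buffer is allocated when the
/// `Mean` reduction needs it.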
#[allow(clippy::too_many_arguments)]
pub fn scatter_reduce_impl(
client: &CpuClient,
dst: &Tensor<CpuRuntime>,
dim: usize,
index: &Tensor<CpuRuntime>,
src: &Tensor<CpuRuntime>,
op: ScatterReduceOp,
include_self: bool,
) -> Result<Tensor<CpuRuntime>> {
let dtype = dst.dtype();
let shape = dst.shape();
let ndim = shape.len();
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
let index_i64 = normalize_indices_to_i64(client, index)?;
if src.dtype() != dtype {
return Err(Error::DTypeMismatch {
lhs: dtype,
rhs: src.dtype(),
});
}
if index_i64.shape() != src.shape() {
return Err(Error::ShapeMismatch {
expected: src.shape().to_vec(),
got: index_i64.shape().to_vec(),
});
}
if index_i64.ndim() != ndim {
return Err(Error::ShapeMismatch {
expected: shape.to_vec(),
got: index_i64.shape().to_vec(),
});
}
let dst_contig = ensure_contiguous(dst);
let index_contig = ensure_contiguous(&index_i64);
let src_contig = ensure_contiguous(src);
let out = Tensor::<CpuRuntime>::empty(shape, dtype, &client.device);
let dst_numel: usize = shape.iter().product();
    // Per-element hit counts; only the `Mean` reduction consults this buffer.
    let mut counts_buffer: Vec<u32> = vec![0; dst_numel];
let dst_ptr = dst_contig.ptr();
let index_ptr = index_contig.ptr();
let src_ptr = src_contig.ptr();
let out_ptr = out.ptr();
let counts_ptr = if op == ScatterReduceOp::Mean {
        counts_buffer.as_mut_ptr()
} else {
std::ptr::null_mut()
};
dispatch_dtype!(dtype, T => {
unsafe {
kernels::scatter_reduce_kernel::<T>(
dst_ptr as *const T,
index_ptr as *const i64,
src_ptr as *const T,
out_ptr as *mut T,
counts_ptr,
shape,
index_i64.shape(),
dim,
op,
include_self,
);
}
}, "scatter_reduce");
Ok(out)
}
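/// Gathers slices of `input` addressed by coordinate tuples stored in the last
/// dimension of `indices`; the output keeps the leading index dimensions and
/// the trailing, un-indexed input dimensions.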
pub fn gather_nd_impl(
client: &CpuClient,
input: &Tensor<CpuRuntime>,
indices: &Tensor<CpuRuntime>,
) -> Result<Tensor<CpuRuntime>> {
let dtype = input.dtype();
let input_shape = input.shape();
let indices_i64 = normalize_indices_to_i64(client, indices)?;
let indices_shape = indices_i64.shape();
if indices_shape.is_empty() {
return Err(Error::ShapeMismatch {
expected: vec![1],
got: indices_shape.to_vec(),
});
}
let indices_ndim = indices_shape.len();
let index_depth = indices_shape[indices_ndim - 1];
if index_depth > input_shape.len() {
return Err(Error::InvalidDimension {
dim: index_depth as isize,
ndim: input_shape.len(),
});
}
let mut out_shape: Vec<usize> = indices_shape[..indices_ndim - 1].to_vec();
out_shape.extend_from_slice(&input_shape[index_depth..]);
if out_shape.is_empty() {
out_shape.push(1);
}
let input_contig = ensure_contiguous(input);
let indices_contig = ensure_contiguous(&indices_i64);
let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &client.device);
let input_ptr = input_contig.ptr();
let indices_ptr = indices_contig.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
kernels::gather_nd_kernel::<T>(
input_ptr as *const T,
indices_ptr as *const i64,
out_ptr as *mut T,
input_shape,
indices_shape,
&out_shape,
);
}
}, "gather_nd");
Ok(out)
}
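/// Counts occurrences of each value in a 1-D non-negative `I32`/`I64` tensor,
/// or accumulates `weights` when given; the output length is `max(input) + 1`,
/// extended to at least `minlength`.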
pub fn bincount_impl(
client: &CpuClient,
input: &Tensor<CpuRuntime>,
weights: Option<&Tensor<CpuRuntime>>,
minlength: usize,
) -> Result<Tensor<CpuRuntime>> {
if input.ndim() != 1 {
return Err(Error::ShapeMismatch {
expected: vec![input.numel()],
got: input.shape().to_vec(),
});
}
let input_dtype = input.dtype();
if !matches!(input_dtype, DType::I32 | DType::I64) {
return Err(Error::DTypeMismatch {
lhs: DType::I64,
rhs: input_dtype,
});
}
let out_dtype = if let Some(w) = weights {
if w.shape() != input.shape() {
return Err(Error::ShapeMismatch {
expected: input.shape().to_vec(),
got: w.shape().to_vec(),
});
}
w.dtype()
} else {
        DType::I64
    };
let input_contig = ensure_contiguous(input);
let numel = input.numel();
let input_i64: Vec<i64> = if input_dtype == DType::I64 {
unsafe { std::slice::from_raw_parts(input_contig.ptr() as *const i64, numel).to_vec() }
} else {
let i32_slice =
unsafe { std::slice::from_raw_parts(input_contig.ptr() as *const i32, numel) };
i32_slice.iter().map(|&x| x as i64).collect()
};
let max_val = unsafe { kernels::max_i64_kernel(input_i64.as_ptr(), numel) };
if max_val < 0 {
return Err(Error::InvalidArgument {
arg: "input",
reason: "bincount requires non-negative values".to_string(),
});
}
let output_len = (max_val as usize + 1).max(minlength);
let out = Tensor::<CpuRuntime>::empty(&[output_len], out_dtype, &client.device);
let out_ptr = out.ptr();
if let Some(w) = weights {
let w_contig = ensure_contiguous(w);
let w_ptr = w_contig.ptr();
dispatch_dtype!(out_dtype, T => {
let success = unsafe {
kernels::bincount_kernel::<T>(
input_i64.as_ptr(),
w_ptr as *const T,
out_ptr as *mut T,
numel,
output_len,
)
};
if !success {
return Err(Error::InvalidArgument {
arg: "input",
reason: "bincount requires non-negative values".to_string(),
});
}
}, "bincount");
} else {
let success = unsafe {
kernels::bincount_kernel::<i64>(
input_i64.as_ptr(),
std::ptr::null(),
out_ptr as *mut i64,
numel,
output_len,
)
};
if !success {
return Err(Error::InvalidArgument {
arg: "input",
reason: "bincount requires non-negative values".to_string(),
});
}
}
Ok(out)
}
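/// Returns a copy of `dst` in which the window of length `src.shape()[dim]`
/// starting at `start` along `dim` is replaced by `src`.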
pub fn slice_assign_impl(
client: &CpuClient,
dst: &Tensor<CpuRuntime>,
src: &Tensor<CpuRuntime>,
dim: usize,
start: usize,
) -> Result<Tensor<CpuRuntime>> {
let ndim = dst.ndim();
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
if src.ndim() != ndim {
return Err(Error::ShapeMismatch {
expected: dst.shape().to_vec(),
got: src.shape().to_vec(),
});
}
for d in 0..ndim {
if d != dim && src.shape()[d] != dst.shape()[d] {
return Err(Error::ShapeMismatch {
expected: dst.shape().to_vec(),
got: src.shape().to_vec(),
});
}
}
let src_dim_size = src.shape()[dim];
let dst_dim_size = dst.shape()[dim];
if start + src_dim_size > dst_dim_size {
return Err(Error::InvalidArgument {
arg: "start",
reason: format!(
"start ({}) + src dim size ({}) exceeds dst dim size ({})",
start, src_dim_size, dst_dim_size
),
});
}
let dtype = dst.dtype();
if src.dtype() != dtype {
return Err(Error::DTypeMismatch {
lhs: dtype,
rhs: src.dtype(),
});
}
let outer_size: usize = dst.shape()[..dim].iter().product();
let outer_size = if outer_size == 0 { 1 } else { outer_size };
let inner_size: usize = dst.shape()[dim + 1..].iter().product();
let inner_size = if inner_size == 0 { 1 } else { inner_size };
let dst_c = ensure_contiguous(dst);
let src_c = ensure_contiguous(src);
let out = Tensor::<CpuRuntime>::empty(dst.shape(), dtype, &client.device);
let dst_ptr = dst_c.ptr();
let src_ptr = src_c.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
kernels::slice_assign_kernel::<T>(
dst_ptr as *const T,
src_ptr as *const T,
out_ptr as *mut T,
outer_size,
dst_dim_size,
src_dim_size,
inner_size,
start,
);
}
}, "slice_assign");
Ok(out)
}
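// A minimal usage sketch, not part of the original API surface: this
// hypothetical helper only illustrates the expected `gather_2d_impl`
// semantics, assuming `Tensor::from_slice`/`to_vec` accept `f32` data the same
// way they are used with `i64` above.
#[allow(dead_code)]
fn gather_2d_usage_sketch(client: &CpuClient) -> Result<Vec<f32>> {
    // A 2x3 row-major matrix: [[1, 2, 3], [4, 5, 6]].
    let input = Tensor::<CpuRuntime>::from_slice(
        &[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
        &[2, 3],
        &client.device,
    );
    // Select (row 0, col 2) and (row 1, col 0); the expected result is [3.0, 4.0].
    let rows = Tensor::<CpuRuntime>::from_slice(&[0i64, 1], &[2], &client.device);
    let cols = Tensor::<CpuRuntime>::from_slice(&[2i64, 0], &[2], &client.device);
    let gathered = gather_2d_impl(client, &input, &rows, &cols)?;
    Ok(gathered.to_vec())
}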