use crate::algorithm::linalg::helpers::{linalg_demote, linalg_promote};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{ReduceOps, ScatterReduceOp, TypeConversionOps};
use crate::runtime::cuda::kernels::{
ScatterReduceOpCuda, launch_bincount_weighted, launch_copy, launch_embedding_lookup,
launch_fill_with_f64, launch_gather_nd, launch_scatter_reduce, launch_scatter_reduce_count,
launch_scatter_reduce_mean_div,
};
use crate::runtime::cuda::{CudaClient, CudaRuntime};
use crate::runtime::{Runtime, compute_contiguous_strides, ensure_contiguous};
use crate::tensor::Tensor;
use super::helpers::normalize_indices_to_i64;
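
/// Looks up rows of a 2-D embedding table on the GPU.
///
/// `embeddings` must have shape `[vocab_size, embedding_dim]`; `indices` may have any
/// shape and any supported integer dtype (it is normalized to `i64` first). The output
/// has shape `indices.shape() + [embedding_dim]` and the dtype of `embeddings`.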
pub fn embedding_lookup(
client: &CudaClient,
embeddings: &Tensor<CudaRuntime>,
indices: &Tensor<CudaRuntime>,
) -> Result<Tensor<CudaRuntime>> {
let dtype = embeddings.dtype();
let emb_shape = embeddings.shape();
if emb_shape.len() != 2 {
return Err(Error::ShapeMismatch {
            expected: vec![0, 0],
            got: emb_shape.to_vec(),
});
}
let indices_i64 = normalize_indices_to_i64(client, indices)?;
let vocab_size = emb_shape[0];
let embedding_dim = emb_shape[1];
let num_indices = indices_i64.numel();
let mut out_shape = indices_i64.shape().to_vec();
out_shape.push(embedding_dim);
let emb_contig = ensure_contiguous(embeddings);
let idx_contig = ensure_contiguous(&indices_i64);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
unsafe {
launch_embedding_lookup(
&client.context,
&client.stream,
client.device.index,
dtype,
emb_contig.ptr(),
idx_contig.ptr(),
out.ptr(),
num_indices,
vocab_size,
embedding_dim,
)?;
}
Ok(out)
}
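
/// Scatter-reduces `src` into a copy of `dst` along dimension `dim`, broadly following
/// the usual `scatter_reduce` semantics (for `dim == 0`: `out[index[i][j]][j] <op>= src[i][j]`).
///
/// `index` must have the same shape as `src` and is normalized to `i64`. With
/// `include_self == false` the output is seeded with the reduction identity rather than
/// `dst`'s values. Float dtypes other than `F32`/`F64` are promoted, reduced, and demoted
/// back to the original dtype. `Mean` is computed as a `Sum` scatter followed by division
/// by per-element contribution counts, so it requires `F32` or `F64`. Returns a new
/// tensor; `dst` itself is left untouched.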
pub fn scatter_reduce(
client: &CudaClient,
dst: &Tensor<CudaRuntime>,
dim: usize,
index: &Tensor<CudaRuntime>,
src: &Tensor<CudaRuntime>,
op: ScatterReduceOp,
include_self: bool,
) -> Result<Tensor<CudaRuntime>> {
let dtype = dst.dtype();
if dtype.is_float() && !matches!(dtype, DType::F32 | DType::F64) {
let (dst_promoted, orig_dtype) = linalg_promote(client, dst)?;
let (src_promoted, _) = linalg_promote(client, src)?;
let result = scatter_reduce(
client,
&dst_promoted,
dim,
index,
&src_promoted,
op,
include_self,
)?;
return linalg_demote(client, result, orig_dtype);
}
let shape = dst.shape();
let ndim = shape.len();
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
let index_i64 = normalize_indices_to_i64(client, index)?;
if src.dtype() != dtype {
return Err(Error::DTypeMismatch {
lhs: dtype,
rhs: src.dtype(),
});
}
if index_i64.shape() != src.shape() {
return Err(Error::ShapeMismatch {
expected: src.shape().to_vec(),
got: index_i64.shape().to_vec(),
});
}
if index_i64.ndim() != ndim {
return Err(Error::ShapeMismatch {
expected: shape.to_vec(),
got: index_i64.shape().to_vec(),
});
}
    // Mean is computed as a Sum scatter followed by division by per-element counts,
    // so it only makes sense for float destinations; reject other dtypes up front
    // before any kernels are launched.
    if matches!(op, ScatterReduceOp::Mean) && !matches!(dtype, DType::F32 | DType::F64) {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "scatter_reduce_mean",
        });
    }
    let cuda_op = match op {
        ScatterReduceOp::Sum => ScatterReduceOpCuda::Sum,
        ScatterReduceOp::Max => ScatterReduceOpCuda::Max,
        ScatterReduceOp::Min => ScatterReduceOpCuda::Min,
        ScatterReduceOp::Prod => ScatterReduceOpCuda::Prod,
        // Mean starts as a Sum; the division by counts happens after the launch.
        ScatterReduceOp::Mean => ScatterReduceOpCuda::Sum,
    };
let dst_contig = ensure_contiguous(dst);
let index_contig = ensure_contiguous(&index_i64);
let src_contig = ensure_contiguous(src);
let out = Tensor::<CudaRuntime>::empty(shape, dtype, &client.device);
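    // Seed the output: copy `dst` when include_self, otherwise fill with the identity
    // element of the reduction so untouched slots stay neutral.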
if include_self {
unsafe {
launch_copy(
&client.context,
&client.stream,
client.device.index,
dtype,
dst_contig.ptr(),
out.ptr(),
dst.numel(),
)?;
}
} else {
let identity = match op {
ScatterReduceOp::Sum | ScatterReduceOp::Mean => 0.0,
ScatterReduceOp::Max => f64::NEG_INFINITY,
ScatterReduceOp::Min => f64::INFINITY,
ScatterReduceOp::Prod => 1.0,
};
unsafe {
launch_fill_with_f64(
&client.context,
&client.stream,
client.device.index,
dtype,
identity,
out.ptr(),
dst.numel(),
)?;
}
}
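    // Decompose the contiguous layout around `dim` so the kernel can address elements
    // as (outer, dim, inner) triples.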
let outer_size: usize = shape[..dim].iter().product();
let dim_size = shape[dim];
let inner_size: usize = shape[dim + 1..].iter().product();
let src_dim_size = src.shape()[dim];
unsafe {
launch_scatter_reduce(
&client.context,
&client.stream,
client.device.index,
dtype,
src_contig.ptr(),
index_contig.ptr(),
out.ptr(),
dim,
outer_size,
dim_size,
inner_size,
src_dim_size,
cuda_op,
)?;
}
if matches!(op, ScatterReduceOp::Mean) {
        // The Sum scatter above accumulated totals; divide each element by the number
        // of contributions it received to obtain the mean. Counts start at 1.0 when
        // `dst`'s own value participates (include_self), otherwise at 0.0, and the
        // count kernel then adds one per scattered `src` element.
        let count = Tensor::<CudaRuntime>::empty(shape, dtype, &client.device);
        let initial_count = if include_self { 1.0 } else { 0.0 };
        unsafe {
            launch_fill_with_f64(
                &client.context,
                &client.stream,
                client.device.index,
                dtype,
                initial_count,
                count.ptr(),
                dst.numel(),
            )?;
        }
unsafe {
launch_scatter_reduce_count(
&client.context,
&client.stream,
client.device.index,
dtype,
index_contig.ptr(),
count.ptr(),
dim,
outer_size,
dim_size,
inner_size,
src_dim_size,
)?;
}
let result = Tensor::<CudaRuntime>::empty(shape, dtype, &client.device);
unsafe {
launch_scatter_reduce_mean_div(
&client.context,
&client.stream,
client.device.index,
dtype,
out.ptr(),
count.ptr(),
result.ptr(),
dst.numel(),
)?;
}
return Ok(result);
}
Ok(out)
}
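
/// Gathers slices of `input` addressed by the trailing dimension of `indices`
/// (N-dimensional gather).
///
/// The last dimension of `indices` holds an index tuple of depth `index_depth` into the
/// leading dimensions of `input`. The output shape is
/// `indices.shape()[..-1] ++ input.shape()[index_depth..]`, falling back to `[1]` when
/// that would be empty (a single fully-indexed element).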
pub fn gather_nd(
client: &CudaClient,
input: &Tensor<CudaRuntime>,
indices: &Tensor<CudaRuntime>,
) -> Result<Tensor<CudaRuntime>> {
let dtype = input.dtype();
let input_shape = input.shape();
let indices_i64 = normalize_indices_to_i64(client, indices)?;
let indices_shape = indices_i64.shape();
if indices_shape.is_empty() {
return Err(Error::ShapeMismatch {
expected: vec![1],
got: indices_shape.to_vec(),
});
}
let indices_ndim = indices_shape.len();
let index_depth = indices_shape[indices_ndim - 1];
if index_depth > input_shape.len() {
return Err(Error::InvalidDimension {
dim: index_depth as isize,
ndim: input_shape.len(),
});
}
let mut out_shape: Vec<usize> = indices_shape[..indices_ndim - 1].to_vec();
out_shape.extend_from_slice(&input_shape[index_depth..]);
if out_shape.is_empty() {
out_shape.push(1);
}
let num_slices: usize = indices_shape[..indices_ndim - 1].iter().product();
let num_slices = num_slices.max(1);
let slice_size: usize = input_shape[index_depth..].iter().product();
let slice_size = slice_size.max(1);
let input_contig = ensure_contiguous(input);
let indices_contig = ensure_contiguous(&indices_i64);
let out = Tensor::<CudaRuntime>::empty(&out_shape, dtype, &client.device);
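    // The kernel needs the input's shape and contiguous strides; stage them in two small
    // device buffers that are freed as soon as the launch has been issued.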
let input_shape_u32: Vec<u32> = input_shape.iter().map(|&s| s as u32).collect();
let input_strides: Vec<usize> = compute_contiguous_strides(input_shape);
let input_strides_u32: Vec<u32> = input_strides.iter().map(|&s| s as u32).collect();
let ndim = input_shape.len();
let shape_bytes = ndim * std::mem::size_of::<u32>();
let shape_ptr = CudaRuntime::allocate(shape_bytes, &client.device)?;
let strides_ptr = CudaRuntime::allocate(shape_bytes, &client.device)?;
CudaRuntime::copy_to_device(
bytemuck::cast_slice(&input_shape_u32),
shape_ptr,
&client.device,
)?;
CudaRuntime::copy_to_device(
bytemuck::cast_slice(&input_strides_u32),
strides_ptr,
&client.device,
)?;
let result = unsafe {
launch_gather_nd(
&client.context,
&client.stream,
client.device.index,
dtype,
input_contig.ptr(),
indices_contig.ptr(),
out.ptr(),
shape_ptr,
strides_ptr,
num_slices,
slice_size,
index_depth,
ndim,
)
};
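    // Free the temporary shape/stride buffers before propagating any launch error.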
CudaRuntime::deallocate(shape_ptr, shape_bytes, &client.device);
CudaRuntime::deallocate(strides_ptr, shape_bytes, &client.device);
result?;
Ok(out)
}
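
/// Counts occurrences of each value in a 1-D tensor of non-negative `I32`/`I64` integers.
///
/// Without `weights` the result is an `I64` histogram; with `weights` (same shape as
/// `input`) each occurrence contributes its weight and the output takes the weights'
/// dtype. The output has `max(input) + 1` bins, or `minlength` bins if that is larger.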
pub fn bincount(
client: &CudaClient,
input: &Tensor<CudaRuntime>,
weights: Option<&Tensor<CudaRuntime>>,
minlength: usize,
) -> Result<Tensor<CudaRuntime>> {
if input.ndim() != 1 {
return Err(Error::ShapeMismatch {
expected: vec![input.numel()],
got: input.shape().to_vec(),
});
}
let input_dtype = input.dtype();
if !matches!(input_dtype, DType::I32 | DType::I64) {
return Err(Error::DTypeMismatch {
lhs: DType::I64,
rhs: input_dtype,
});
}
let weights_dtype = if let Some(w) = weights {
if w.shape() != input.shape() {
return Err(Error::ShapeMismatch {
expected: input.shape().to_vec(),
got: w.shape().to_vec(),
});
}
Some(w.dtype())
} else {
None
};
let out_dtype = weights_dtype.unwrap_or(DType::I64);
let input_contig = ensure_contiguous(input);
let numel = input.numel();
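    // Size the output from the largest input value: cast to F64, reduce on device,
    // and read the scalar maximum back to the host.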
let input_f64 = client.cast(input, DType::F64)?;
let max_tensor = client.max(&input_f64, &[0], false)?;
let max_val = max_tensor.item::<f64>()? as i64;
if max_val < 0 {
return Err(Error::InvalidArgument {
arg: "input",
reason: "bincount requires non-negative values".to_string(),
});
}
let output_len = ((max_val as usize) + 1).max(minlength);
let out = Tensor::<CudaRuntime>::empty(&[output_len], out_dtype, &client.device);
unsafe {
launch_fill_with_f64(
&client.context,
&client.stream,
client.device.index,
out_dtype,
0.0,
out.ptr(),
output_len,
)?;
}
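    // With weights, the kernel accumulates each element's weight instead of a unit count.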
let weights_contig = weights.map(ensure_contiguous);
let weights_ptr = weights_contig.as_ref().map(|w| w.ptr());
unsafe {
launch_bincount_weighted(
&client.context,
&client.stream,
client.device.index,
input_dtype,
weights_dtype,
input_contig.ptr(),
weights_ptr,
out.ptr(),
numel,
output_len,
)?;
}
Ok(out)
}