use crate::{
CubeRuntime,
kernel::utils::{address_type, broadcast_strides, shape_divmod},
ops::numeric::empty_device_dtype,
tensor::CubeTensor,
};
use burn_backend::TensorMetadata;
use cubecl::frontend::{ABSOLUTE_POS, Numeric, Tensor};
use cubecl::std::{FastDivmod, tensor::index_offset_contiguous_fastdivmod};
use cubecl::{CubeDim, std::tensor::layout::linear::LinearView};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
/// Element-wise gather kernel: one unit handles one linear position of `indices`/`output`.
///
/// For the position `ABSOLUTE_POS`, the kernel:
/// 1. derives a base offset into `input` from `out_shape` and `in_strides`
///    (via `index_offset_contiguous_fastdivmod`),
/// 2. adds `indices[ABSOLUTE_POS] * input.stride(dim)` to that base offset,
/// 3. copies `input[offset]` into `output[ABSOLUTE_POS]`.
///
/// # Parameters
/// - `input`: tensor the values are read from.
/// - `indices`: linear view holding the index to use along `dim` for each output position.
/// - `output`: linear view receiving the gathered values.
/// - `in_strides`: per-axis strides used for the base offset. The launcher in this file
///   sets the entry for `dim` to 0, so the `dim` contribution comes only from the index.
/// - `out_shape`: per-axis fast-divmod helpers used to decompose the linear position.
/// - `dim`: axis of `input` that is addressed through `indices`.
/// - `_dtypes`: storage types bound to the generic parameters `T` and `I`.
///
/// NOTE(review): the index value itself is not range-checked in this kernel; it is
/// presumably expected to be valid for axis `dim` of `input` — confirm with callers.
#[cube(launch_unchecked, address_type = "dynamic")]
fn gather_kernel<T: Numeric, I: Numeric>(
    input: &Tensor<T>,
    indices: &LinearView<I>,
    output: &mut LinearView<T, ReadWrite>,
    in_strides: Sequence<usize>,
    out_shape: Sequence<FastDivmod<usize>>,
    dim: usize,
    #[define(T, I)] _dtypes: [StorageType; 2],
) {
    // Units whose position falls outside `indices` have nothing to do.
    if !indices.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    // Offset into `input` derived from the output position and `in_strides`.
    let base_offset = index_offset_contiguous_fastdivmod(
        ABSOLUTE_POS,
        &out_shape,
        &in_strides,
        input.vector_size(),
    );

    // The coordinate along `dim` is taken from `indices` rather than from the position.
    let gathered_index = usize::cast_from(indices[ABSOLUTE_POS]);
    let source_offset = base_offset + gathered_index * input.stride(dim);

    output[ABSOLUTE_POS] = input[source_offset];
}
/// Gathers values from `tensor` along axis `dim` at the positions given by `indices`.
///
/// The returned tensor has the shape of `indices` and the dtype of `tensor`; it is
/// allocated on the same client and device as `tensor`. One kernel unit is scheduled per
/// output element.
///
/// # Parameters
/// - `dim`: axis of `tensor` that is addressed through `indices`.
/// - `tensor`: source tensor (consumed).
/// - `indices`: tensor of indices along `dim` (consumed); its shape defines the output shape.
///
/// # Panics
/// Indexing `in_strides.values[dim]` panics if `dim` is not a valid axis of the
/// broadcast strides.
pub(crate) fn gather<R: CubeRuntime>(
    dim: usize,
    tensor: CubeTensor<R>,
    indices: CubeTensor<R>,
) -> CubeTensor<R> {
    // Read the dtypes up front: both tensors are consumed when building launch arguments.
    let value_dtype = tensor.dtype;
    let index_dtype = indices.dtype;

    // The output mirrors the shape of `indices` and the element type of `tensor`.
    let out_shape = indices.shape();
    let num_elems = out_shape.num_elements();
    let output = empty_device_dtype(
        tensor.client.clone(),
        tensor.device.clone(),
        out_shape,
        value_dtype,
    );

    // Launch geometry: element-wise over every output position.
    let cube_dim = CubeDim::new(&tensor.client, num_elems);
    let cube_count = calculate_cube_count_elemwise(&tensor.client, num_elems, cube_dim);

    // Zero the stride of the gathered axis so the kernel's base offset ignores the
    // output coordinate along `dim`; the kernel adds `index * stride(dim)` itself.
    let mut in_strides = broadcast_strides(&output, &tensor);
    in_strides.values[dim] = 0;

    let out_shape_divmod = shape_divmod(&output);
    let addr_type = address_type!(tensor, indices, output);

    // SAFETY (review): the launch is unchecked. The kernel exits early for positions
    // outside `indices`, but reads from `input` rely on the index values being valid
    // for axis `dim` — presumably guaranteed by callers; confirm.
    unsafe {
        gather_kernel::launch_unchecked(
            &output.client,
            cube_count,
            cube_dim,
            addr_type,
            tensor.into_tensor_arg(),
            indices.into_linear_view(),
            output.clone().into_linear_view(),
            in_strides,
            out_shape_divmod,
            dim,
            [value_dtype.into(), index_dtype.into()],
        )
    }

    output
}