use crate::{
CubeRuntime,
kernel::{
AddOp, AssignOp, BinaryMaxOp, BinaryMinOp, BinaryOp, BinaryOpFamily, MulOp,
utils::{address_type, shape_divmod_range},
},
tensor::CubeTensor,
};
use burn_backend::tensor::IndexingUpdateOp;
use cubecl::std::tensor::layout::linear::LinearView;
use cubecl::{CubeDim, calculate_cube_count_elemwise};
use cubecl::{prelude::*, std::FastDivmod};
#[cube(launch_unchecked, address_type = "dynamic")]
fn scatter_nd_kernel<T: Numeric, I: Int, Op: BinaryOpFamily>(
data: &mut Tensor<T>,
indices: &LinearView<I>,
values: &Tensor<T>,
data_slice_shape: Sequence<FastDivmod<usize>>,
values_shape: Sequence<FastDivmod<usize>>,
slice_size: usize,
k: usize,
working_units: usize,
#[define(T, I)] _dtypes: [StorageType; 2],
) {
if ABSOLUTE_POS >= working_units {
terminate!();
}
let slice_offset = ABSOLUTE_POS % slice_size;
let update_idx = ABSOLUTE_POS / slice_size;
let idx_base = update_idx * k;
let mut base_offset = 0usize;
for j in 0..k {
let idx_val = usize::cast_from(indices[idx_base + j]);
base_offset += idx_val * data.stride(j);
}
let slice_rank = data_slice_shape.len().comptime();
let mut data_slice_offset = 0usize;
let mut remainder = slice_offset;
#[unroll]
for i in 0..slice_rank {
let dim = slice_rank - i - 1;
let (rem, coord) = data_slice_shape[dim].div_mod(remainder);
remainder = rem;
data_slice_offset += coord * data.stride(k + dim);
}
let val_rank = values_shape.len().comptime();
let mut val_offset = update_idx * values.stride(0);
let mut remainder_v = slice_offset;
#[unroll]
for i in 0..val_rank {
let dim = val_rank - i - 1;
let (rem, coord) = values_shape[dim].div_mod(remainder_v);
remainder_v = rem;
val_offset += coord * values.stride(1 + dim);
}
let data_idx = base_offset + data_slice_offset;
let result = Op::BinaryOp::<T, Const<1>>::execute(
Vector::cast_from(data[data_idx]),
Vector::cast_from(values[val_offset]),
);
data[data_idx] = result[0];
}
pub(crate) fn scatter_nd<R: CubeRuntime>(
tensor: CubeTensor<R>,
indices: CubeTensor<R>,
values: CubeTensor<R>,
reduction: IndexingUpdateOp,
) -> CubeTensor<R> {
let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
true => tensor,
false => tensor.copy(),
};
let data_shape = &tensor.meta.shape;
let idx_shape = &indices.meta.shape;
let m = idx_shape.num_dims();
let k = idx_shape[m - 1];
let num_updates: usize = idx_shape.as_slice()[..m - 1].iter().product();
let slice_size: usize = data_shape.as_slice()[k..].iter().product();
let working_units = num_updates * slice_size;
let cube_dim = CubeDim::new(&indices.client, working_units);
let cube_count = calculate_cube_count_elemwise(&indices.client, working_units, cube_dim);
let (tensor_dtype, indices_dtype) = (tensor.dtype, indices.dtype);
let launch = match reduction {
IndexingUpdateOp::Assign => scatter_nd_kernel::launch_unchecked::<AssignOp, R>,
IndexingUpdateOp::Add => scatter_nd_kernel::launch_unchecked::<AddOp, R>,
IndexingUpdateOp::Mul => scatter_nd_kernel::launch_unchecked::<MulOp, R>,
IndexingUpdateOp::Min => scatter_nd_kernel::launch_unchecked::<BinaryMinOp, R>,
IndexingUpdateOp::Max => scatter_nd_kernel::launch_unchecked::<BinaryMaxOp, R>,
};
let data_slice_shape = shape_divmod_range(&tensor, k..data_shape.num_dims());
let values_slice_shape = shape_divmod_range(&values, 1..values.meta.shape.num_dims());
unsafe {
launch(
&tensor.client.clone(),
cube_count,
cube_dim,
address_type!(tensor, indices, values),
tensor.clone().into_tensor_arg(),
indices.into_linear_view(),
values.into_tensor_arg(),
data_slice_shape,
values_slice_shape,
slice_size,
k,
working_units,
[tensor_dtype.into(), indices_dtype.into()],
)
}
tensor
}