use crate::{
CubeRuntime,
kernel::utils::{address_type, shape_divmod},
tensor::CubeTensor,
};
use cubecl::{
calculate_cube_count_elemwise, intrinsic,
prelude::*,
std::{FastDivmod, tensor::layout::linear::LinearView},
};
/// Copies every vector of `value` into the sub-region of `input` selected
/// by the slice description.
///
/// * `slice_shape` — per-dimension lengths of the slice, pre-packed as
///   `FastDivmod` so the linear→multi-dim decomposition is cheap.
/// * `slice_offsets` — per-dimension start index of the slice in `input`.
///
/// One invocation handles one `Vector<E, N>`; `ABSOLUTE_POS` is the linear
/// vector index into `value`.
#[cube(launch_unchecked, address_type = "dynamic")]
fn slice_assign_kernel<E: Numeric, N: Size>(
    input: &mut Tensor<Vector<E, N>>,
    value: &LinearView<Vector<E, N>>,
    slice_shape: Sequence<FastDivmod<usize>>,
    slice_offsets: Sequence<usize>,
    #[define(E)] _dtype: StorageType,
) {
    // Surplus invocations from cube-count rounding exit without writing.
    if !value.is_in_bounds(ABSOLUTE_POS) {
        terminate!()
    }
    let rank = comptime!(slice_shape.len());
    let line_size = input.vector_size();
    // Linear offset into the slice in SCALAR elements (not vectors).
    let mut offset_remainder = ABSOLUTE_POS * line_size;
    let mut offset_input = 0;
    // Peel off one coordinate per dimension, innermost first, and
    // accumulate the strided scalar offset into `input`.
    #[allow(clippy::explicit_counter_loop)]
    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;
        let (rem, offset_local) = slice_shape[dim].div_mod(offset_remainder);
        let range_start = slice_offsets[dim];
        // Shift the slice-local coordinate by the slice's start in `input`.
        let offset_local_input = offset_local + range_start;
        offset_input += offset_local_input * input.stride(dim);
        offset_remainder = rem;
    }
    // Convert the scalar offset back to a vector index for the store.
    // (The launcher only vectorizes when the last stride is 1 and the
    // offset/shape are multiples of the vector size, so this divides evenly.)
    input[offset_input / line_size] = value[ABSOLUTE_POS];
}
/// Scalar (non-vectorized) slice assignment supporting arbitrary step
/// sizes, including negative steps.
///
/// * `value_shape` — per-dimension shape of `value` as `FastDivmod`.
/// * `starts`/`ends` — resolved per-dimension bounds of the slice in `input`.
/// * `steps` — per-dimension step; negative steps walk backward from `end - 1`.
///
/// One invocation copies one scalar element; `ABSOLUTE_POS` is the linear
/// index into `value`.
#[cube(launch_unchecked, address_type = "dynamic")]
fn slice_assign_with_steps_kernel<E: Numeric>(
    input: &mut Tensor<E>,
    value: &LinearView<E>,
    value_shape: Sequence<FastDivmod<usize>>,
    starts: Sequence<usize>,
    ends: Sequence<usize>,
    steps: Sequence<i32>,
    #[define(E)] _dtype: StorageType,
) {
    // Surplus invocations from cube-count rounding exit without writing.
    if !value.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }
    let rank = comptime![value_shape.len()];
    let mut value_offset = ABSOLUTE_POS;
    let mut input_offset = 0;
    // Decompose the linear value index into per-dimension coordinates,
    // innermost dimension first, mapping each onto the stepped slice.
    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;
        let start = starts[dim];
        let end = ends[dim];
        let step = steps[dim];
        let (rem, value_idx) = value_shape[dim].div_mod(value_offset);
        value_offset = rem;
        let input_idx = if step > 0 {
            // Forward step: walk from `start` in strides of `step`.
            start + value_idx * (step as usize)
        } else if step < 0 {
            // Negative step: walk backward from the last in-range index.
            let abs_step = (-step) as usize;
            let end_minus_1 = end - 1;
            end_minus_1 - value_idx * abs_step
        } else {
            // step == 0: identity mapping. NOTE(review): the host wrapper
            // treats step 0 like a unit step; confirm this is intentional.
            value_idx
        };
        input_offset += input_idx * input.stride(dim);
    }
    input[input_offset] = value[ABSOLUTE_POS];
}
/// Writes `value` into the region of `tensor` selected by `indices`,
/// returning the tensor containing the result.
///
/// Slices with unit (or zero, treated as default) steps take the fast
/// vectorized path; any non-unit step dispatches to
/// [`slice_assign_with_steps`]. The input is mutated in place only when
/// it is uniquely owned and non-overlapping; otherwise it is copied first.
pub(crate) fn slice_assign<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    indices: &[burn_backend::Slice],
    value: CubeTensor<R>,
) -> CubeTensor<R> {
    // Steps of 1 (and 0, the "unspecified" default) can use the
    // vectorized kernel; anything else needs the stepped kernel.
    let has_non_unit_step = indices.iter().any(|s| s.step != 1 && s.step != 0);
    if has_non_unit_step {
        return slice_assign_with_steps(tensor, indices, value);
    }
    let client = tensor.client.clone();
    // Only mutate in place when no other handle aliases this storage.
    let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
        true => tensor,
        false => tensor.copy(),
    };
    let ndims = tensor.meta.num_dims();
    // Vectorization is only sound when both tensors are contiguous in the
    // last dimension; pick the widest vector size that divides the slice
    // length, the slice start, and every stride of both tensors.
    let vector_size =
        if tensor.meta.strides()[ndims - 1] == 1 && value.meta.strides()[ndims - 1] == 1 {
            // Missing trailing slices default to the full dimension.
            let last = indices
                .get(ndims - 1)
                .cloned()
                .unwrap_or(burn_backend::Slice {
                    start: 0,
                    end: Some(tensor.meta.shape()[ndims - 1] as isize),
                    step: 1,
                });
            let end = last.end.unwrap_or(tensor.meta.shape()[ndims - 1] as isize);
            let shape = (end - last.start) as usize;
            let offset = last.start as usize;
            client
                .io_optimized_vector_sizes(tensor.dtype.size())
                .filter(|&it| {
                    shape.is_multiple_of(it)
                        && strides_compatible(tensor.meta.strides(), it)
                        && strides_compatible(value.meta.strides(), it)
                        && offset.is_multiple_of(it)
                })
                .max()
                .unwrap_or(1)
        } else {
            1
        };
    // Build per-dimension slice lengths (as FastDivmod) and start offsets;
    // unspecified dimensions cover their full extent.
    let mut shape = SequenceArg::<R, FastDivmod<usize>>::new();
    let mut offsets = SequenceArg::<R, usize>::new();
    for i in 0..ndims {
        let slice = indices.get(i).cloned().unwrap_or(burn_backend::Slice {
            start: 0,
            end: Some(tensor.meta.shape()[i] as isize),
            step: 1,
        });
        let start = slice.start as usize;
        let end = slice.end.unwrap_or(tensor.meta.shape()[i] as isize);
        let length = (end - slice.start) as usize;
        shape.push(length);
        offsets.push(start);
    }
    // One working unit per vector of `value`.
    let working_units = value.meta.num_elements() / vector_size;
    let cube_dim = CubeDim::new(&tensor.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
    // SAFETY: the kernel bounds-checks each invocation via `is_in_bounds`.
    unsafe {
        slice_assign_kernel::launch_unchecked(
            &tensor.client,
            cube_count,
            cube_dim,
            address_type!(tensor, value),
            vector_size,
            tensor.clone().into_tensor_arg(),
            value.into_linear_view(),
            shape,
            offsets,
            tensor.dtype.into(),
        )
    };
    tensor
}
/// Writes `value` into the region of `tensor` selected by `slices`,
/// supporting non-unit (including negative) step sizes.
///
/// Dimensions beyond `slices.len()` default to their full range with a
/// step of 1. The input is mutated in place only when it is uniquely
/// owned and non-overlapping; otherwise it is copied first. Returns the
/// tensor containing the assigned values.
pub(crate) fn slice_assign_with_steps<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    slices: &[burn_backend::Slice],
    value: CubeTensor<R>,
) -> CubeTensor<R> {
    // Only mutate in place when no other handle aliases this storage.
    let tensor = match tensor.can_mut() && tensor.is_nonoverlapping() {
        true => tensor,
        false => tensor.copy(),
    };
    let mut starts = SequenceArg::<R, usize>::new();
    let mut ends = SequenceArg::<R, usize>::new();
    let mut steps = SequenceArg::<R, i32>::new();
    for (dim, slice) in slices.iter().enumerate() {
        // `to_range` resolves negative / open-ended bounds against the
        // dimension size.
        let range = slice.to_range(tensor.meta.shape()[dim]);
        starts.push(range.start);
        ends.push(range.end);
        // NOTE(review): `as` narrows isize -> i32 and would silently
        // truncate astronomically large steps; not practical in tensors.
        steps.push(slice.step as i32);
    }
    // Unspecified trailing dimensions cover their full extent.
    for dim in slices.len()..tensor.meta.num_dims() {
        starts.push(0);
        // Use the `shape()` accessor for consistency with the rest of
        // this file (was a direct `shape[dim]` field access).
        ends.push(tensor.meta.shape()[dim]);
        steps.push(1);
    }
    // One working unit per scalar element of `value` — stepped accesses
    // are not contiguous, so no vectorization here.
    let working_units = value.meta.num_elements();
    let cube_dim = CubeDim::new(&tensor.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
    let shape = shape_divmod(&value);
    // SAFETY: the kernel bounds-checks each invocation via `is_in_bounds`.
    unsafe {
        slice_assign_with_steps_kernel::launch_unchecked(
            &tensor.client,
            cube_count,
            cube_dim,
            address_type!(tensor, value),
            tensor.clone().into_tensor_arg(),
            value.into_linear_view(),
            shape,
            starts,
            ends,
            steps,
            tensor.dtype.into(),
        );
    }
    tensor
}
/// Returns `true` when every stride is safe for vectorized access of
/// width `vec`: either the contiguous stride `1`, or a multiple of `vec`
/// so vector loads/stores never straddle a dimension boundary.
fn strides_compatible(strides: &[usize], vec: usize) -> bool {
    !strides
        .iter()
        .any(|&stride| stride != 1 && stride % vec != 0)
}
/// Extracts the compile-time constant behind `value` during kernel
/// expansion, panicking at expansion time if `value` is not a comptime
/// constant (`constant()` returns `None`).
#[allow(unused)]
#[cube]
fn unwrap(value: u32) -> comptime_type!(u32) {
    intrinsic!(|_| value.constant().unwrap().as_u32())
}