burn_cubecl/kernel/index/slice_assign.rs

use crate::{
    CubeRuntime,
    kernel::utils::{linear_view, shape_divmod},
    tensor::CubeTensor,
};
use cubecl::{
    calculate_cube_count_elemwise, intrinsic,
    prelude::*,
    std::{FastDivmod, FastDivmodArgs, tensor::layout::linear::LinearView},
};

/// Kernel for slice assign with unit steps.
#[cube(launch_unchecked)]
fn slice_assign_kernel<E: Numeric>(
    input: &mut Tensor<Line<E>>,
    value: &LinearView<Line<E>>,
    slice_shape: Sequence<FastDivmod<usize>>,
    slice_offsets: Sequence<usize>,
    #[define(E)] _dtype: StorageType,
) {
    if !value.is_in_bounds(ABSOLUTE_POS) {
        terminate!()
    }

    let rank = comptime!(slice_shape.len());

    let line_size = input.line_size();
    let mut offset_remainder = ABSOLUTE_POS * line_size;
    let mut offset_input = 0;

    #[allow(clippy::explicit_counter_loop)]
    #[unroll]
    for i in 0..rank {
        // Iterate in reverse (innermost dimension first) to use divmod
        let dim = rank - i - 1;
        let (rem, offset_local) = slice_shape[dim].div_mod(offset_remainder);

        let range_start = slice_offsets[dim];
        let offset_local_input = offset_local + range_start;

        offset_input += offset_local_input * input.stride(dim);
        offset_remainder = rem;
    }

    // Value tensor is accessed linearly since it's a LinearView
    input[offset_input / line_size] = value[ABSOLUTE_POS];
}

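// Illustrative sketch (not part of the kernel above): the same offset decomposition as
// plain host-side Rust, ignoring vectorization (line_size = 1) and using `%` / `/` in
// place of `FastDivmod`. Starting from a linear position into the value view, each
// iteration peels off one coordinate of the slice (innermost dimension first), shifts it
// by the slice start, and accumulates it against the input strides. The function name
// and signature are hypothetical.
#[allow(unused)]
fn slice_assign_offset_sketch(
    linear_pos: usize,
    slice_shape: &[usize],
    slice_offsets: &[usize],
    input_strides: &[usize],
) -> usize {
    let rank = slice_shape.len();
    let mut offset_remainder = linear_pos;
    let mut offset_input = 0;
    for i in 0..rank {
        let dim = rank - i - 1;
        // div_mod equivalent: the remainder is the coordinate within this dimension of
        // the slice, the quotient carries over to the next (outer) dimension.
        let offset_local = offset_remainder % slice_shape[dim];
        offset_remainder /= slice_shape[dim];
        offset_input += (offset_local + slice_offsets[dim]) * input_strides[dim];
    }
    offset_input
}
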
/// Kernel for slice assign with steps
#[cube(launch_unchecked)]
fn slice_assign_with_steps_kernel<E: Numeric>(
    input: &mut Tensor<E>,
    value: &LinearView<E>,
    value_shape: Sequence<FastDivmod<usize>>,
    starts: Sequence<usize>,
    ends: Sequence<usize>,
    steps: Sequence<i32>,
    #[define(E)] _dtype: StorageType,
) {
    if !value.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![value_shape.len()];
    let mut value_offset = ABSOLUTE_POS;
    let mut input_offset = 0;

    // Calculate the input offset based on value position and slice info
    #[unroll]
    for i in 0..rank {
        // Iterate in reverse to use divmod
        let dim = rank - i - 1;
        let start = starts[dim];
        let end = ends[dim];
        let step = steps[dim];

        let (rem, value_idx) = value_shape[dim].div_mod(value_offset);
        value_offset = rem;

        let input_idx = if step > 0 {
            // Forward stepping
            start + value_idx * (step as usize)
        } else if step < 0 {
            // Backward stepping - start from end-1
            // For negative steps, we iterate backwards through the selected indices
            let abs_step = (-step) as usize;
            let end_minus_1 = end - 1;
            end_minus_1 - value_idx * abs_step
        } else {
            // step == 0, shouldn't happen
            value_idx
        };

        input_offset += input_idx * input.stride(dim);
    }

    input[input_offset] = value[ABSOLUTE_POS];
}

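// Illustrative sketch (not part of the kernel above): the per-dimension index mapping for
// stepped slices as plain host-side Rust. For a positive step the i-th value lands at
// `start + i * step`; for a negative step it lands at `end - 1 - i * |step|`, walking
// backwards through the selected indices. The function name is hypothetical.
#[allow(unused)]
fn stepped_input_index_sketch(value_idx: usize, start: usize, end: usize, step: i32) -> usize {
    if step > 0 {
        start + value_idx * step as usize
    } else if step < 0 {
        (end - 1) - value_idx * (-step) as usize
    } else {
        // step == 0 is invalid input; mirror the kernel's fallback
        value_idx
    }
}
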
/// Assign `value` into the region of `tensor` selected by `indices`, dispatching to the
/// stepped kernel when any slice has a non-unit step.
pub(crate) fn slice_assign<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    indices: &[burn_backend::Slice],
    value: CubeTensor<R>,
) -> CubeTensor<R> {
    // Check if any slice has a non-unit step (steps of 0 are ignored here)
    let has_non_unit_step = indices.iter().any(|s| s.step != 1 && s.step != 0);

    if has_non_unit_step {
        // Use slice_assign_with_steps
        return slice_assign_with_steps(tensor, indices, value);
    }

    let client = tensor.client.clone();
    let tensor = match tensor.can_mut() {
        true => tensor,
        false => tensor.copy(),
    };
    let ndims = tensor.shape.num_dims();

    // Vectorize along the innermost dimension only when both tensors are contiguous there;
    // the line size must divide the slice length, the start offset, and every stride
    // (strides of 1 are allowed).
    let line_size = if tensor.strides[ndims - 1] == 1 && value.strides[ndims - 1] == 1 {
        let last = indices
            .get(ndims - 1)
            .cloned()
            .unwrap_or(burn_backend::Slice {
                start: 0,
                end: Some(tensor.shape[ndims - 1] as isize),
                step: 1,
            });
        let end = last.end.unwrap_or(tensor.shape[ndims - 1] as isize);
        let shape = (end - last.start) as usize;
        let offset = last.start as usize;
        *R::supported_line_sizes()
            .iter()
            .filter(|it| {
                let it = **it;
                shape.is_multiple_of(it)
                    && strides_compatible(&tensor.strides, it)
                    && strides_compatible(&value.strides, it)
                    && offset.is_multiple_of(it)
            })
            .max()
            .unwrap_or(&1)
    } else {
        1
    };

    let mut shape = SequenceArg::<R, FastDivmod<usize>>::new();
    let mut offsets = SequenceArg::<R, usize>::new();

    for i in 0..ndims {
        let slice = indices.get(i).cloned().unwrap_or(burn_backend::Slice {
            start: 0,
            end: Some(tensor.shape[i] as isize),
            step: 1,
        });
        let start = slice.start as usize;
        let end = slice.end.unwrap_or(tensor.shape[i] as isize);
        let length = (end - slice.start) as usize;

        shape.push(FastDivmodArgs::<usize>::new(&client, length));
        offsets.push(ScalarArg::new(start));
    }

    let working_units = value.shape.num_elements() / line_size;
    let cube_dim = CubeDim::new(&tensor.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);

    unsafe {
        slice_assign_kernel::launch_unchecked(
            &tensor.client,
            cube_count,
            cube_dim,
            tensor.as_tensor_arg(line_size),
            linear_view(&value, line_size),
            shape,
            offsets,
            tensor.dtype.into(),
        )
        .expect("Kernel to never fail");
    }

    tensor
}

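// Illustrative sketch (not part of the dispatch above): how the vectorization width is
// chosen, assuming a plain slice of candidate line sizes. A candidate is kept when the
// innermost slice length and start offset are multiples of it and the strides of both
// tensors pass `strides_compatible`; the largest surviving candidate wins, defaulting
// to 1. The function name and parameters are hypothetical.
#[allow(unused)]
fn pick_line_size_sketch(
    candidates: &[usize],
    inner_len: usize,
    inner_offset: usize,
    tensor_strides: &[usize],
    value_strides: &[usize],
) -> usize {
    candidates
        .iter()
        .copied()
        .filter(|&candidate| {
            inner_len % candidate == 0
                && inner_offset % candidate == 0
                && strides_compatible(tensor_strides, candidate)
                && strides_compatible(value_strides, candidate)
        })
        .max()
        .unwrap_or(1)
}
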
/// Slice assign with steps support
///
/// This function handles slice assignment with arbitrary step values, including negative steps.
/// It follows NumPy/PyTorch semantics, where `values[i]` is assigned to `selected_indices[i]`.
///
/// For example, with `s![0..6;-1]`, which selects indices [5, 4, 3, 2, 1, 0]:
/// - values[0] goes to index 5
/// - values[1] goes to index 4
/// - etc.
pub(crate) fn slice_assign_with_steps<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    slices: &[burn_backend::Slice],
    value: CubeTensor<R>,
) -> CubeTensor<R> {
    let tensor = match tensor.can_mut() {
        true => tensor,
        false => tensor.copy(),
    };

    // Prepare sequences for kernel
    let mut starts = SequenceArg::<R, usize>::new();
    let mut ends = SequenceArg::<R, usize>::new();
    let mut steps = SequenceArg::<R, i32>::new();

    for (dim, slice) in slices.iter().enumerate() {
        let range = slice.to_range(tensor.shape[dim]);
        starts.push(ScalarArg::new(range.start));
        ends.push(ScalarArg::new(range.end));
        steps.push(ScalarArg::new(slice.step as i32));
    }

    // Pad with default values if needed to match tensor dimensions
    for dim in slices.len()..tensor.shape.num_dims() {
        starts.push(ScalarArg::new(0));
        ends.push(ScalarArg::new(tensor.shape[dim]));
        steps.push(ScalarArg::new(1));
    }

    // Launch kernel
    let working_units = value.shape.num_elements();
    let cube_dim = CubeDim::new(&tensor.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);

    unsafe {
        slice_assign_with_steps_kernel::launch_unchecked(
            &tensor.client,
            cube_count,
            cube_dim,
            tensor.as_tensor_arg(1),
            linear_view(&value, 1),
            shape_divmod(&value),
            starts,
            ends,
            steps,
            tensor.dtype.into(),
        )
        .expect("Kernel to never fail");
    }

    tensor
}

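// Illustrative check (not part of this module) of the doc-comment example above: with
// `start = 0`, `end = 6`, `step = -1`, value index `i` maps to input index `5 - i`, so
// values[0] goes to index 5, values[1] to index 4, and so on. The function name is
// hypothetical.
#[allow(unused)]
fn negative_step_example_sketch() {
    let (start, end, step): (usize, usize, i32) = (0, 6, -1);
    for value_idx in 0..6usize {
        let input_idx = if step > 0 {
            start + value_idx * step as usize
        } else {
            (end - 1) - value_idx * (-step) as usize
        };
        assert_eq!(input_idx, 5 - value_idx);
    }
}
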
/// Returns `true` when every stride is either a multiple of the vectorization width or
/// equal to 1 (the contiguous innermost dimension).
fn strides_compatible(strides: &[usize], vec: usize) -> bool {
    strides
        .iter()
        .all(|stride| *stride % vec == 0 || *stride == 1)
}

/// Helper that extracts the compile-time constant of a `u32` value as a plain `u32`.
#[allow(unused)]
#[cube]
fn unwrap(value: u32) -> comptime_type!(u32) {
    intrinsic!(|_| value.constant().unwrap().as_u32())
}