burn_cubecl/kernel/index/
slice.rs

1use crate::{
2    CubeRuntime,
3    kernel::utils::{linear_view, shape_divmod},
4    ops::numeric::empty_device_dtype,
5    tensor::CubeTensor,
6};
7use burn_backend::Slice;
8use cubecl::{
9    calculate_cube_count_elemwise, intrinsic,
10    prelude::*,
11    std::{FastDivmod, tensor::layout::linear::LinearView},
12};
13use std::ops::Range;
14
15/// Slice a jit tensor with a set of ranges
16pub fn slice<R: CubeRuntime>(tensor: CubeTensor<R>, indices: &[Range<usize>]) -> CubeTensor<R> {
17    let mut dims = tensor.shape.clone();
18    let mut offset_start = 0u64;
19    let mut offset_end = 0u64;
20
21    for i in 0..indices.len() {
22        offset_start += (tensor.strides[i] * indices[i].start) as u64;
23        offset_end += (tensor.strides[i] * (dims[i] - indices[i].end)) as u64;
24        dims[i] = indices[i].end - indices[i].start;
25    }
26
27    let offset_start = offset_start * tensor.dtype.size() as u64;
28    let offset_end = offset_end * tensor.dtype.size() as u64;
29
30    let memory_offset_alignment = tensor.client.properties().memory.alignment;
31
32    if offset_start.is_multiple_of(memory_offset_alignment)
33        && offset_end.is_multiple_of(memory_offset_alignment)
34    {
35        CubeTensor::new(
36            tensor.client,
37            tensor
38                .handle
39                .offset_start(offset_start)
40                .offset_end(offset_end),
41            dims,
42            tensor.device,
43            tensor.strides,
44            tensor.dtype,
45        )
46    } else {
47        let output = empty_device_dtype(
48            tensor.client.clone(),
49            tensor.device.clone(),
50            dims,
51            tensor.dtype,
52        );
53        slice_on_output(tensor, output, indices)
54    }
55}
56
/// Copy kernel for plain (step = 1) slicing: each unit copies one output element.
///
/// `out_shape` holds one `FastDivmod` per output dimension so the linear
/// position can be decomposed into per-dimension coordinates; `indices` holds
/// the range start of each dimension, used to shift coordinates back into the
/// input's index space.
#[cube(launch_unchecked)]
fn slice_kernel<E: Numeric>(
    input: &Tensor<E>,
    output: &mut LinearView<E, ReadWrite>,
    out_shape: Sequence<FastDivmod<usize>>,
    indices: Sequence<usize>,
    #[define(E)] _dtype: StorageType,
) {
    // Extra units past the output size do nothing.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![out_shape.len()];
    let mut offset_output = ABSOLUTE_POS;
    let mut offset_input = 0;

    #[unroll]
    for i in 0..rank {
        // Iterate in reverse to use divmod
        let dim = rank - i - 1;

        let range_start = indices[dim];
        // Peel off this dimension's coordinate; `rem` carries the remaining dims.
        let (rem, offset_local) = out_shape[dim].div_mod(offset_output);
        offset_output = rem;

        // Shift the output coordinate by the slice start to get the input coordinate.
        let offset_local = offset_local + range_start;

        offset_input += offset_local * input.stride(dim);
    }

    output[ABSOLUTE_POS] = input[offset_input];
}
89
90pub(crate) fn slice_on_output<R: CubeRuntime>(
91    tensor: CubeTensor<R>,
92    output: CubeTensor<R>,
93    indices: &[Range<usize>],
94) -> CubeTensor<R> {
95    let ndims = tensor.shape.num_dims();
96    let mut indices_sequence = SequenceArg::<R, usize>::new();
97
98    for i in 0..ndims {
99        let start = indices.get(i).map(|index| index.start).unwrap_or(0);
100        indices_sequence.push(ScalarArg::new(start));
101    }
102
103    let working_units = output.shape.num_elements();
104    let cube_dim = CubeDim::new(&tensor.client, working_units);
105    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
106
107    unsafe {
108        slice_kernel::launch_unchecked(
109            &tensor.client,
110            cube_count,
111            cube_dim,
112            tensor.as_tensor_arg(1),
113            linear_view(&output, 1),
114            shape_divmod(&output),
115            indices_sequence,
116            tensor.dtype.into(),
117        )
118        .expect("Kernel to never fail");
119    };
120
121    output
122}
123
124/// Kernel for slicing with steps
/// Kernel for slicing with steps: each unit copies one output element whose
/// input coordinate is derived from the per-dimension `starts`/`ends`/`steps`.
///
/// Positive steps index forward from `start`; negative steps index backward
/// from `end - 1`.
#[cube(launch_unchecked)]
fn slice_with_steps_kernel<E: Numeric>(
    input: &Tensor<E>,
    output: &mut LinearView<E, ReadWrite>,
    out_shape: Sequence<FastDivmod<usize>>,
    starts: Sequence<usize>,
    ends: Sequence<usize>,
    steps: Sequence<i32>,
    #[define(E)] _dtype: StorageType,
) {
    // Extra units past the output size do nothing.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![out_shape.len()];
    let mut output_offset = ABSOLUTE_POS;
    let mut input_offset = 0;

    // Calculate the input offset based on output position and slice info
    #[unroll]
    for i in 0..rank {
        // Iterate in reverse to use divmod
        let dim = rank - i - 1;
        let start = starts[dim];
        let end = ends[dim];
        let step = steps[dim];

        // Peel off this dimension's output coordinate; `rem` carries the rest.
        let (rem, output_idx) = out_shape[dim].div_mod(output_offset);
        output_offset = rem;

        let input_idx = if step > 0 {
            // Forward stepping
            start + output_idx * (step as usize)
        } else {
            // Backward stepping - start from end-1
            let abs_step = (-step) as usize;
            let end_minus_1 = end - 1;
            end_minus_1 - output_idx * abs_step
        };

        input_offset += input_idx * input.stride(dim);
    }

    output[ABSOLUTE_POS] = input[input_offset];
}
170
171/// Slice a tensor with steps
172pub fn slice_with_steps<R: CubeRuntime>(tensor: CubeTensor<R>, slices: &[Slice]) -> CubeTensor<R> {
173    // Check if all steps are 1 - if so, use the optimized regular slice
174    let all_steps_one = slices.iter().all(|info| info.step == 1);
175
176    if all_steps_one {
177        // Convert Slice to Range for step=1
178        let simple_ranges: Vec<Range<usize>> = slices
179            .iter()
180            .enumerate()
181            .map(|(i, slice)| slice.to_range(tensor.shape[i]))
182            .collect();
183        return slice(tensor, &simple_ranges);
184    }
185
186    // Calculate output shape
187    let shape_output = tensor.shape.clone().slice(slices).unwrap();
188
189    // Create output tensor
190    let output = empty_device_dtype(
191        tensor.client.clone(),
192        tensor.device.clone(),
193        shape_output.clone(),
194        tensor.dtype,
195    );
196
197    // Prepare three separate sequences for kernel
198    let mut starts = SequenceArg::<R, usize>::new();
199    let mut ends = SequenceArg::<R, usize>::new();
200    let mut steps = SequenceArg::<R, i32>::new();
201
202    for (dim, slice) in slices.iter().enumerate() {
203        let range = slice.to_range(tensor.shape[dim]);
204        starts.push(ScalarArg::new(range.start));
205        ends.push(ScalarArg::new(range.end));
206        steps.push(ScalarArg::new(slice.step as i32));
207    }
208
209    // Pad with default values if needed to match tensor dimensions
210    for dim in slices.len()..tensor.shape.num_dims() {
211        starts.push(ScalarArg::new(0));
212        ends.push(ScalarArg::new(tensor.shape[dim]));
213        steps.push(ScalarArg::new(1));
214    }
215
216    // Launch kernel
217    let working_units = shape_output.num_elements();
218    let cube_dim = CubeDim::new(&tensor.client, working_units);
219    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
220
221    unsafe {
222        slice_with_steps_kernel::launch_unchecked(
223            &tensor.client,
224            cube_count,
225            cube_dim,
226            tensor.as_tensor_arg(1),
227            linear_view(&output, 1),
228            shape_divmod(&output),
229            starts,
230            ends,
231            steps,
232            tensor.dtype.into(),
233        )
234        .expect("Kernel to never fail");
235    }
236
237    output
238}
239
/// This is annoying and we need to find a way to do this automatically at some point
// NOTE(review): extracts the compile-time constant behind a `u32` kernel value;
// `.unwrap()` presumably fails at expansion time if `value` is not a constant —
// confirm against the cubecl `intrinsic!` contract.
#[allow(unused)]
#[cube]
fn unwrap(value: u32) -> comptime_type!(u32) {
    intrinsic!(|_| value.constant().unwrap().as_u32())
}