// burn_cubecl/kernel/index/slice.rs
1use crate::{
2    CubeRuntime,
3    kernel::utils::{address_type, shape_divmod},
4    ops::numeric::empty_device_dtype,
5    tensor::CubeTensor,
6};
7use burn_backend::{Slice, TensorMetadata};
8use burn_std::{Metadata, SliceOps};
9use cubecl::{
10    calculate_cube_count_elemwise, intrinsic,
11    prelude::*,
12    std::{FastDivmod, tensor::layout::linear::LinearView},
13};
14use std::ops::Range;
15
16/// Slice a jit tensor with a set of ranges
17pub fn slice<R: CubeRuntime>(tensor: CubeTensor<R>, indices: &[Range<usize>]) -> CubeTensor<R> {
18    let mut dims = tensor.shape();
19    let mut offset_start = 0u64;
20    let mut offset_end = 0u64;
21
22    for i in 0..indices.len() {
23        offset_start += (tensor.meta.strides()[i] * indices[i].start) as u64;
24        offset_end += (tensor.meta.strides()[i] * (dims[i] - indices[i].end)) as u64;
25        dims[i] = indices[i].end - indices[i].start;
26    }
27
28    let offset_start = offset_start * tensor.dtype.size() as u64;
29    let offset_end = offset_end * tensor.dtype.size() as u64;
30
31    let memory_offset_alignment = tensor.client.properties().memory.alignment;
32
33    if offset_start.is_multiple_of(memory_offset_alignment)
34        && offset_end.is_multiple_of(memory_offset_alignment)
35    {
36        CubeTensor::new(
37            tensor.client.clone(),
38            tensor
39                .handle
40                .clone()
41                .offset_start(offset_start)
42                .offset_end(offset_end),
43            Metadata::new(dims, tensor.meta.strides.clone()),
44            tensor.device.clone(),
45            tensor.dtype,
46        )
47    } else {
48        let output = empty_device_dtype(
49            tensor.client.clone(),
50            tensor.device.clone(),
51            dims,
52            tensor.dtype,
53        );
54        slice_on_output(tensor, output, indices)
55    }
56}
57
/// Copy kernel for a contiguous-range slice (step = 1 in every dimension).
///
/// Each work item writes one output element: it decomposes `ABSOLUTE_POS`
/// into per-dimension output coordinates using the precomputed divmods in
/// `out_shape`, shifts each coordinate by the slice start from `indices`,
/// and gathers the corresponding input element via the input strides.
#[cube(launch_unchecked, address_type = "dynamic")]
fn slice_kernel<E: Numeric>(
    input: &Tensor<E>,
    output: &mut LinearView<E, ReadWrite>,
    out_shape: Sequence<FastDivmod<usize>>,
    indices: Sequence<usize>,
    #[define(E)] _dtype: StorageType,
) {
    // Extra launched units past the output size have nothing to do.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![out_shape.len()];
    let mut offset_output = ABSOLUTE_POS;
    let mut offset_input = 0;

    #[unroll]
    for i in 0..rank {
        // Iterate in reverse to use divmod
        let dim = rank - i - 1;

        let range_start = indices[dim];
        // `rem` carries the remaining linear index for the outer dimensions;
        // `offset_local` is this dimension's output coordinate.
        let (rem, offset_local) = out_shape[dim].div_mod(offset_output);
        offset_output = rem;

        // Translate output coordinate to input coordinate by the slice start.
        let offset_local = offset_local + range_start;

        offset_input += offset_local * input.stride(dim);
    }

    output[ABSOLUTE_POS] = input[offset_input];
}
90
91pub(crate) fn slice_on_output<R: CubeRuntime>(
92    tensor: CubeTensor<R>,
93    output: CubeTensor<R>,
94    indices: &[Range<usize>],
95) -> CubeTensor<R> {
96    let ndims = tensor.meta.num_dims();
97    let mut indices_sequence = SequenceArg::<R, usize>::new();
98
99    for i in 0..ndims {
100        let start = indices.get(i).map(|index| index.start).unwrap_or(0);
101        indices_sequence.push(start);
102    }
103
104    let working_units = output.meta.num_elements();
105    let cube_dim = CubeDim::new(&tensor.client, working_units);
106    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
107    let dtype = tensor.dtype;
108
109    unsafe {
110        slice_kernel::launch_unchecked(
111            &output.client,
112            cube_count,
113            cube_dim,
114            address_type!(tensor, output),
115            tensor.into_tensor_arg(),
116            output.clone().into_linear_view(),
117            shape_divmod(&output),
118            indices_sequence,
119            dtype.into(),
120        )
121    };
122
123    output
124}
125
126/// Kernel for slicing with steps
127#[cube(launch_unchecked, address_type = "dynamic")]
128fn slice_with_steps_kernel<E: Numeric>(
129    input: &Tensor<E>,
130    output: &mut LinearView<E, ReadWrite>,
131    out_shape: Sequence<FastDivmod<usize>>,
132    starts: Sequence<usize>,
133    ends: Sequence<usize>,
134    steps: Sequence<i32>,
135    #[define(E)] _dtype: StorageType,
136) {
137    if !output.is_in_bounds(ABSOLUTE_POS) {
138        terminate!();
139    }
140
141    let rank = comptime![out_shape.len()];
142    let mut output_offset = ABSOLUTE_POS;
143    let mut input_offset = 0;
144
145    // Calculate the input offset based on output position and slice info
146    #[unroll]
147    for i in 0..rank {
148        // Iterate in reverse to use divmod
149        let dim = rank - i - 1;
150        let start = starts[dim];
151        let end = ends[dim];
152        let step = steps[dim];
153
154        let (rem, output_idx) = out_shape[dim].div_mod(output_offset);
155        output_offset = rem;
156
157        let input_idx = if step > 0 {
158            // Forward stepping
159            start + output_idx * (step as usize)
160        } else {
161            // Backward stepping - start from end-1
162            let abs_step = (-step) as usize;
163            let end_minus_1 = end - 1;
164            end_minus_1 - output_idx * abs_step
165        };
166
167        input_offset += input_idx * input.stride(dim);
168    }
169
170    output[ABSOLUTE_POS] = input[input_offset];
171}
172
173/// Slice a tensor with steps
174pub fn slice_with_steps<R: CubeRuntime>(tensor: CubeTensor<R>, slices: &[Slice]) -> CubeTensor<R> {
175    // Check if all steps are 1 - if so, use the optimized regular slice
176    let all_steps_one = slices.iter().all(|info| info.step == 1);
177
178    if all_steps_one {
179        // Convert Slice to Range for step=1
180        let simple_ranges: Vec<Range<usize>> = slices
181            .iter()
182            .enumerate()
183            .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))
184            .collect();
185        return slice(tensor, &simple_ranges);
186    }
187
188    // Calculate output shape
189    let shape_output = tensor.shape().slice(slices).unwrap();
190
191    // Create output tensor
192    let output = empty_device_dtype(
193        tensor.client.clone(),
194        tensor.device.clone(),
195        shape_output.clone(),
196        tensor.dtype,
197    );
198
199    // Prepare three separate sequences for kernel
200    let mut starts = SequenceArg::<R, usize>::new();
201    let mut ends = SequenceArg::<R, usize>::new();
202    let mut steps = SequenceArg::<R, i32>::new();
203
204    for (dim, slice) in slices.iter().enumerate() {
205        let range = slice.to_range(tensor.meta.shape()[dim]);
206        starts.push(range.start);
207        ends.push(range.end);
208        steps.push(slice.step as i32);
209    }
210
211    // Pad with default values if needed to match tensor dimensions
212    for dim in slices.len()..tensor.meta.num_dims() {
213        starts.push(0);
214        ends.push(tensor.meta.shape[dim]);
215        steps.push(1);
216    }
217
218    // Launch kernel
219    let working_units = shape_output.num_elements();
220    let cube_dim = CubeDim::new(&tensor.client, working_units);
221    let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
222    let dtype = tensor.dtype;
223
224    unsafe {
225        slice_with_steps_kernel::launch_unchecked(
226            &output.client,
227            cube_count,
228            cube_dim,
229            address_type!(tensor, output),
230            tensor.into_tensor_arg(),
231            output.clone().into_linear_view(),
232            shape_divmod(&output),
233            starts,
234            ends,
235            steps,
236            dtype.into(),
237        );
238    }
239
240    output
241}
242
243/// This is annoying and we need to find a way to do this automatically at some point
244#[allow(unused)]
245#[cube]
246fn unwrap(value: u32) -> comptime_type!(u32) {
247    intrinsic!(|_| value.constant().unwrap().as_u32())
248}