// burn_cubecl/kernel/index/slice.rs

1use crate::{CubeRuntime, element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor};
2use burn_tensor::Shape;
3use cubecl::{calculate_cube_count_elemwise, prelude::*};
4use std::ops::Range;
5
6/// Slice a jit tensor with a set of ranges
7pub fn slice<R: CubeRuntime, E: CubeElement>(
8    tensor: CubeTensor<R>,
9    indices: &[Range<usize>],
10) -> CubeTensor<R> {
11    let mut dims = tensor.shape.dims.clone();
12    let mut offset_start = 0u64;
13    let mut offset_end = 0u64;
14
15    for i in 0..indices.len() {
16        offset_start += (tensor.strides[i] * indices[i].start) as u64;
17        offset_end += (tensor.strides[i] * (dims[i] - indices[i].end)) as u64;
18        dims[i] = indices[i].end - indices[i].start;
19    }
20
21    let offset_start = offset_start * E::cube_elem().size() as u64;
22    let offset_end = offset_end * E::cube_elem().size() as u64;
23
24    let memory_offset_alignment = tensor.client.properties().memory_properties().alignment;
25
26    if offset_start % memory_offset_alignment == 0u64
27        && offset_end % memory_offset_alignment == 0u64
28    {
29        CubeTensor::new(
30            tensor.client,
31            tensor
32                .handle
33                .offset_start(offset_start)
34                .offset_end(offset_end),
35            Shape::from(dims),
36            tensor.device,
37            tensor.strides,
38            tensor.dtype,
39        )
40    } else {
41        let shape_output = Shape::from(dims);
42        let output =
43            empty_device::<R, E>(tensor.client.clone(), tensor.device.clone(), shape_output);
44        slice_on_output::<R, E>(tensor, output, indices)
45    }
46}
47
/// Copies one element of a slice per unit.
///
/// Each unit with `ABSOLUTE_POS < output.len()`:
/// - splits its linear position into per-dimension coordinates using `output`'s strides and shape;
/// - shifts each coordinate by that dimension's start offset in `indices`;
/// - reads the input element at the shifted coordinates, using `input`'s strides.
///
/// The coordinate split assumes `output` is contiguous (row-major). `slice` allocates it
/// that way via `empty_device`; TODO confirm for other callers.
///
/// - `indices`: one range-start offset per dimension.
/// - `rank`: comptime number of dimensions to iterate.
///
/// Input reads are not bounds-checked here.
#[cube(launch_unchecked)]
fn slice_kernel<E: CubePrimitive>(
    input: &Tensor<E>,
    output: &mut Tensor<E>,
    indices: Sequence<u32>,
    #[comptime] rank: u32,
) {
    // Units past the last output element (the launch may overshoot) exit early.
    if ABSOLUTE_POS >= output.len() {
        terminate!();
    }

    let mut offset_input = 0;

    // `rank` is comptime, so this loop is unrolled at kernel expansion time.
    #[unroll]
    for i in 0..rank {
        let range_start = *indices.index(i);
        // Output coordinate along dimension `i`, shifted into the input's index space.
        let offset_local = ABSOLUTE_POS / output.stride(i) % output.shape(i) + range_start;

        offset_input += offset_local * input.stride(i);
    }

    output[ABSOLUTE_POS] = input[offset_input];
}
71
72pub(crate) fn slice_on_output<R: CubeRuntime, E: CubeElement>(
73    tensor: CubeTensor<R>,
74    output: CubeTensor<R>,
75    indices: &[Range<usize>],
76) -> CubeTensor<R> {
77    let ndims = tensor.shape.num_dims();
78    let mut indices_sequence = SequenceArg::<R, u32>::new();
79
80    for i in 0..ndims {
81        let start = indices.get(i).map(|index| index.start).unwrap_or(0);
82        indices_sequence.push(ScalarArg::new(start as u32));
83    }
84
85    let cube_dim = CubeDim::default();
86    let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);
87
88    unsafe {
89        slice_kernel::launch_unchecked::<E, R>(
90            &tensor.client,
91            cube_count,
92            cube_dim,
93            tensor.as_tensor_arg::<E>(1),
94            output.as_tensor_arg::<E>(1),
95            indices_sequence,
96            ndims as u32,
97        )
98    };
99
100    output
101}