burn_cubecl/kernel/index/
slice.rs1use crate::{CubeRuntime, element::CubeElement, ops::numeric::empty_device, tensor::CubeTensor};
2use burn_tensor::Shape;
3use cubecl::{calculate_cube_count_elemwise, prelude::*};
4use std::ops::Range;
5
6pub fn slice<R: CubeRuntime, E: CubeElement>(
8 tensor: CubeTensor<R>,
9 indices: &[Range<usize>],
10) -> CubeTensor<R> {
11 let mut dims = tensor.shape.dims.clone();
12 let mut offset_start = 0u64;
13 let mut offset_end = 0u64;
14
15 for i in 0..indices.len() {
16 offset_start += (tensor.strides[i] * indices[i].start) as u64;
17 offset_end += (tensor.strides[i] * (dims[i] - indices[i].end)) as u64;
18 dims[i] = indices[i].end - indices[i].start;
19 }
20
21 let offset_start = offset_start * E::cube_elem().size() as u64;
22 let offset_end = offset_end * E::cube_elem().size() as u64;
23
24 let memory_offset_alignment = tensor.client.properties().memory_properties().alignment;
25
26 if offset_start % memory_offset_alignment == 0u64
27 && offset_end % memory_offset_alignment == 0u64
28 {
29 CubeTensor::new(
30 tensor.client,
31 tensor
32 .handle
33 .offset_start(offset_start)
34 .offset_end(offset_end),
35 Shape::from(dims),
36 tensor.device,
37 tensor.strides,
38 tensor.dtype,
39 )
40 } else {
41 let shape_output = Shape::from(dims);
42 let output =
43 empty_device::<R, E>(tensor.client.clone(), tensor.device.clone(), shape_output);
44 slice_on_output::<R, E>(tensor, output, indices)
45 }
46}
47
48#[cube(launch_unchecked)]
49fn slice_kernel<E: CubePrimitive>(
50 input: &Tensor<E>,
51 output: &mut Tensor<E>,
52 indices: Sequence<u32>,
53 #[comptime] rank: u32,
54) {
55 if ABSOLUTE_POS >= output.len() {
56 terminate!();
57 }
58
59 let mut offset_input = 0;
60
61 #[unroll]
62 for i in 0..rank {
63 let range_start = *indices.index(i);
64 let offset_local = ABSOLUTE_POS / output.stride(i) % output.shape(i) + range_start;
65
66 offset_input += offset_local * input.stride(i);
67 }
68
69 output[ABSOLUTE_POS] = input[offset_input];
70}
71
72pub(crate) fn slice_on_output<R: CubeRuntime, E: CubeElement>(
73 tensor: CubeTensor<R>,
74 output: CubeTensor<R>,
75 indices: &[Range<usize>],
76) -> CubeTensor<R> {
77 let ndims = tensor.shape.num_dims();
78 let mut indices_sequence = SequenceArg::<R, u32>::new();
79
80 for i in 0..ndims {
81 let start = indices.get(i).map(|index| index.start).unwrap_or(0);
82 indices_sequence.push(ScalarArg::new(start as u32));
83 }
84
85 let cube_dim = CubeDim::default();
86 let cube_count = calculate_cube_count_elemwise(output.shape.num_elements(), cube_dim);
87
88 unsafe {
89 slice_kernel::launch_unchecked::<E, R>(
90 &tensor.client,
91 cube_count,
92 cube_dim,
93 tensor.as_tensor_arg::<E>(1),
94 output.as_tensor_arg::<E>(1),
95 indices_sequence,
96 ndims as u32,
97 )
98 };
99
100 output
101}