// burn_cubecl/kernel/index/slice.rs
use crate::{
2 CubeRuntime,
3 kernel::utils::{linear_view, shape_divmod},
4 ops::numeric::empty_device_dtype,
5 tensor::CubeTensor,
6};
7use burn_backend::Slice;
8use cubecl::{
9 calculate_cube_count_elemwise, intrinsic,
10 prelude::*,
11 std::{FastDivmod, tensor::layout::linear::LinearView},
12};
13use std::ops::Range;
14
15pub fn slice<R: CubeRuntime>(tensor: CubeTensor<R>, indices: &[Range<usize>]) -> CubeTensor<R> {
17 let mut dims = tensor.shape.clone();
18 let mut offset_start = 0u64;
19 let mut offset_end = 0u64;
20
21 for i in 0..indices.len() {
22 offset_start += (tensor.strides[i] * indices[i].start) as u64;
23 offset_end += (tensor.strides[i] * (dims[i] - indices[i].end)) as u64;
24 dims[i] = indices[i].end - indices[i].start;
25 }
26
27 let offset_start = offset_start * tensor.dtype.size() as u64;
28 let offset_end = offset_end * tensor.dtype.size() as u64;
29
30 let memory_offset_alignment = tensor.client.properties().memory.alignment;
31
32 if offset_start.is_multiple_of(memory_offset_alignment)
33 && offset_end.is_multiple_of(memory_offset_alignment)
34 {
35 CubeTensor::new(
36 tensor.client,
37 tensor
38 .handle
39 .offset_start(offset_start)
40 .offset_end(offset_end),
41 dims,
42 tensor.device,
43 tensor.strides,
44 tensor.dtype,
45 )
46 } else {
47 let output = empty_device_dtype(
48 tensor.client.clone(),
49 tensor.device.clone(),
50 dims,
51 tensor.dtype,
52 );
53 slice_on_output(tensor, output, indices)
54 }
55}
56
/// Cube kernel that gathers a contiguous-range slice of `input` into `output`.
///
/// Each unit handles one linear output position (`ABSOLUTE_POS`): it
/// decomposes that position into per-dimension coordinates using the
/// `out_shape` fast-divmod helpers (innermost dimension first), shifts each
/// coordinate by the slice start from `indices`, and accumulates the input
/// offset via the input strides.
#[cube(launch_unchecked)]
fn slice_kernel<E: Numeric>(
    input: &Tensor<E>,
    output: &mut LinearView<E, ReadWrite>,
    out_shape: Sequence<FastDivmod<usize>>,
    indices: Sequence<usize>,
    #[define(E)] _dtype: StorageType,
) {
    // Extra units launched beyond the output size exit immediately.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![out_shape.len()];
    let mut offset_output = ABSOLUTE_POS;
    let mut offset_input = 0;

    // Walk dimensions from innermost to outermost, peeling one coordinate
    // off the linear output position at each step.
    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;

        let range_start = indices[dim];
        let (rem, offset_local) = out_shape[dim].div_mod(offset_output);
        offset_output = rem;

        // Shift the output coordinate into input space by the slice start.
        let offset_local = offset_local + range_start;

        offset_input += offset_local * input.stride(dim);
    }

    output[ABSOLUTE_POS] = input[offset_input];
}
89
90pub(crate) fn slice_on_output<R: CubeRuntime>(
91 tensor: CubeTensor<R>,
92 output: CubeTensor<R>,
93 indices: &[Range<usize>],
94) -> CubeTensor<R> {
95 let ndims = tensor.shape.num_dims();
96 let mut indices_sequence = SequenceArg::<R, usize>::new();
97
98 for i in 0..ndims {
99 let start = indices.get(i).map(|index| index.start).unwrap_or(0);
100 indices_sequence.push(ScalarArg::new(start));
101 }
102
103 let working_units = output.shape.num_elements();
104 let cube_dim = CubeDim::new(&tensor.client, working_units);
105 let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
106
107 unsafe {
108 slice_kernel::launch_unchecked(
109 &tensor.client,
110 cube_count,
111 cube_dim,
112 tensor.as_tensor_arg(1),
113 linear_view(&output, 1),
114 shape_divmod(&output),
115 indices_sequence,
116 tensor.dtype.into(),
117 )
118 .expect("Kernel to never fail");
119 };
120
121 output
122}
123
/// Cube kernel that gathers a strided (possibly negative-step) slice of
/// `input` into `output`.
///
/// Each unit handles one linear output position: it decomposes the position
/// into per-dimension output coordinates via `out_shape` (innermost first),
/// then maps each coordinate to an input index using the per-dimension
/// `starts`/`ends`/`steps`. For a positive step the index walks forward from
/// `start`; for a negative step it walks backward from `end - 1`.
#[cube(launch_unchecked)]
fn slice_with_steps_kernel<E: Numeric>(
    input: &Tensor<E>,
    output: &mut LinearView<E, ReadWrite>,
    out_shape: Sequence<FastDivmod<usize>>,
    starts: Sequence<usize>,
    ends: Sequence<usize>,
    steps: Sequence<i32>,
    #[define(E)] _dtype: StorageType,
) {
    // Extra units launched beyond the output size exit immediately.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![out_shape.len()];
    let mut output_offset = ABSOLUTE_POS;
    let mut input_offset = 0;

    // Walk dimensions from innermost to outermost, peeling one coordinate
    // off the linear output position at each step.
    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;
        let start = starts[dim];
        let end = ends[dim];
        let step = steps[dim];

        let (rem, output_idx) = out_shape[dim].div_mod(output_offset);
        output_offset = rem;

        let input_idx = if step > 0 {
            // Forward traversal: start, start+step, ...
            start + output_idx * (step as usize)
        } else {
            // Backward traversal: end-1, end-1-|step|, ...
            // NOTE(review): assumes `end >= 1` for negative steps — presumably
            // guaranteed by `Slice::to_range` on the host side; verify.
            let abs_step = (-step) as usize;
            let end_minus_1 = end - 1;
            end_minus_1 - output_idx * abs_step
        };

        input_offset += input_idx * input.stride(dim);
    }

    output[ABSOLUTE_POS] = input[input_offset];
}
170
171pub fn slice_with_steps<R: CubeRuntime>(tensor: CubeTensor<R>, slices: &[Slice]) -> CubeTensor<R> {
173 let all_steps_one = slices.iter().all(|info| info.step == 1);
175
176 if all_steps_one {
177 let simple_ranges: Vec<Range<usize>> = slices
179 .iter()
180 .enumerate()
181 .map(|(i, slice)| slice.to_range(tensor.shape[i]))
182 .collect();
183 return slice(tensor, &simple_ranges);
184 }
185
186 let shape_output = tensor.shape.clone().slice(slices).unwrap();
188
189 let output = empty_device_dtype(
191 tensor.client.clone(),
192 tensor.device.clone(),
193 shape_output.clone(),
194 tensor.dtype,
195 );
196
197 let mut starts = SequenceArg::<R, usize>::new();
199 let mut ends = SequenceArg::<R, usize>::new();
200 let mut steps = SequenceArg::<R, i32>::new();
201
202 for (dim, slice) in slices.iter().enumerate() {
203 let range = slice.to_range(tensor.shape[dim]);
204 starts.push(ScalarArg::new(range.start));
205 ends.push(ScalarArg::new(range.end));
206 steps.push(ScalarArg::new(slice.step as i32));
207 }
208
209 for dim in slices.len()..tensor.shape.num_dims() {
211 starts.push(ScalarArg::new(0));
212 ends.push(ScalarArg::new(tensor.shape[dim]));
213 steps.push(ScalarArg::new(1));
214 }
215
216 let working_units = shape_output.num_elements();
218 let cube_dim = CubeDim::new(&tensor.client, working_units);
219 let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
220
221 unsafe {
222 slice_with_steps_kernel::launch_unchecked(
223 &tensor.client,
224 cube_count,
225 cube_dim,
226 tensor.as_tensor_arg(1),
227 linear_view(&output, 1),
228 shape_divmod(&output),
229 starts,
230 ends,
231 steps,
232 tensor.dtype.into(),
233 )
234 .expect("Kernel to never fail");
235 }
236
237 output
238}
239
/// Extracts the compile-time constant behind a cube `u32` value.
///
/// Panics (via `unwrap`) if `value` is not a comptime constant when the
/// kernel is expanded; callers must only pass constant-foldable values.
#[allow(unused)]
#[cube]
fn unwrap(value: u32) -> comptime_type!(u32) {
    intrinsic!(|_| value.constant().unwrap().as_u32())
}