burn_cubecl/kernel/index/
slice.rs1use crate::{
2 CubeRuntime,
3 kernel::utils::{address_type, shape_divmod},
4 ops::numeric::empty_device_dtype,
5 tensor::CubeTensor,
6};
7use burn_backend::{Slice, TensorMetadata};
8use burn_std::{Metadata, SliceOps};
9use cubecl::{
10 calculate_cube_count_elemwise, intrinsic,
11 prelude::*,
12 std::{FastDivmod, tensor::layout::linear::LinearView},
13};
14use std::ops::Range;
15
16pub fn slice<R: CubeRuntime>(tensor: CubeTensor<R>, indices: &[Range<usize>]) -> CubeTensor<R> {
18 let mut dims = tensor.shape();
19 let mut offset_start = 0u64;
20 let mut offset_end = 0u64;
21
22 for i in 0..indices.len() {
23 offset_start += (tensor.meta.strides()[i] * indices[i].start) as u64;
24 offset_end += (tensor.meta.strides()[i] * (dims[i] - indices[i].end)) as u64;
25 dims[i] = indices[i].end - indices[i].start;
26 }
27
28 let offset_start = offset_start * tensor.dtype.size() as u64;
29 let offset_end = offset_end * tensor.dtype.size() as u64;
30
31 let memory_offset_alignment = tensor.client.properties().memory.alignment;
32
33 if offset_start.is_multiple_of(memory_offset_alignment)
34 && offset_end.is_multiple_of(memory_offset_alignment)
35 {
36 CubeTensor::new(
37 tensor.client.clone(),
38 tensor
39 .handle
40 .clone()
41 .offset_start(offset_start)
42 .offset_end(offset_end),
43 Metadata::new(dims, tensor.meta.strides.clone()),
44 tensor.device.clone(),
45 tensor.dtype,
46 )
47 } else {
48 let output = empty_device_dtype(
49 tensor.client.clone(),
50 tensor.device.clone(),
51 dims,
52 tensor.dtype,
53 );
54 slice_on_output(tensor, output, indices)
55 }
56}
57
/// Copy kernel for unit-step slicing.
///
/// Each unit handles one linear output element: `ABSOLUTE_POS` is decomposed
/// into per-dimension coordinates via `out_shape` (fast divmod, innermost
/// dimension first), each coordinate is shifted by the slice start from
/// `indices`, and the source element is read through `input`'s strides.
#[cube(launch_unchecked, address_type = "dynamic")]
fn slice_kernel<E: Numeric>(
    input: &Tensor<E>,
    // Linear view over the (contiguous) output buffer.
    output: &mut LinearView<E, ReadWrite>,
    // Output shape as fast-divmod helpers, one per dimension.
    out_shape: Sequence<FastDivmod<usize>>,
    // Slice start index for every dimension.
    indices: Sequence<usize>,
    #[define(E)] _dtype: StorageType,
) {
    // Guard against launch over-provisioning.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![out_shape.len()];
    let mut offset_output = ABSOLUTE_POS;
    let mut offset_input = 0;

    // Walk dimensions from innermost (last) to outermost, peeling one
    // coordinate off the linear position at each step.
    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;

        let range_start = indices[dim];
        let (rem, offset_local) = out_shape[dim].div_mod(offset_output);
        offset_output = rem;

        // Output coordinate + slice start = input coordinate in this dim.
        let offset_local = offset_local + range_start;

        offset_input += offset_local * input.stride(dim);
    }

    output[ABSOLUTE_POS] = input[offset_input];
}
90
91pub(crate) fn slice_on_output<R: CubeRuntime>(
92 tensor: CubeTensor<R>,
93 output: CubeTensor<R>,
94 indices: &[Range<usize>],
95) -> CubeTensor<R> {
96 let ndims = tensor.meta.num_dims();
97 let mut indices_sequence = SequenceArg::<R, usize>::new();
98
99 for i in 0..ndims {
100 let start = indices.get(i).map(|index| index.start).unwrap_or(0);
101 indices_sequence.push(start);
102 }
103
104 let working_units = output.meta.num_elements();
105 let cube_dim = CubeDim::new(&tensor.client, working_units);
106 let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
107 let dtype = tensor.dtype;
108
109 unsafe {
110 slice_kernel::launch_unchecked(
111 &output.client,
112 cube_count,
113 cube_dim,
114 address_type!(tensor, output),
115 tensor.into_tensor_arg(),
116 output.clone().into_linear_view(),
117 shape_divmod(&output),
118 indices_sequence,
119 dtype.into(),
120 )
121 };
122
123 output
124}
125
/// Gather kernel for slicing with arbitrary (including negative) steps.
///
/// Each unit handles one linear output element: `ABSOLUTE_POS` is decomposed
/// into per-dimension output coordinates via `out_shape`, and each coordinate
/// is mapped back to its input index. Positive steps walk forward from
/// `start`; negative steps walk backward starting at `end - 1`.
#[cube(launch_unchecked, address_type = "dynamic")]
fn slice_with_steps_kernel<E: Numeric>(
    input: &Tensor<E>,
    output: &mut LinearView<E, ReadWrite>,
    // Output shape as fast-divmod helpers, one per dimension.
    out_shape: Sequence<FastDivmod<usize>>,
    // Normalized range bounds and step per dimension (built in `slice_with_steps`).
    starts: Sequence<usize>,
    ends: Sequence<usize>,
    steps: Sequence<i32>,
    #[define(E)] _dtype: StorageType,
) {
    // Guard against launch over-provisioning.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![out_shape.len()];
    let mut output_offset = ABSOLUTE_POS;
    let mut input_offset = 0;

    // Walk dimensions from innermost (last) to outermost, peeling one
    // coordinate off the linear position at each step.
    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;
        let start = starts[dim];
        let end = ends[dim];
        let step = steps[dim];

        let (rem, output_idx) = out_shape[dim].div_mod(output_offset);
        output_offset = rem;

        let input_idx = if step > 0 {
            // Forward stride: start, start+step, start+2*step, ...
            start + output_idx * (step as usize)
        } else {
            // Backward stride: the first selected element is `end - 1`.
            // NOTE(review): underflows if `end == 0` on a negative-step dim —
            // presumably `Slice::to_range` guarantees a non-empty range; confirm.
            let abs_step = (-step) as usize;
            let end_minus_1 = end - 1;
            end_minus_1 - output_idx * abs_step
        };

        input_offset += input_idx * input.stride(dim);
    }

    output[ABSOLUTE_POS] = input[input_offset];
}
172
173pub fn slice_with_steps<R: CubeRuntime>(tensor: CubeTensor<R>, slices: &[Slice]) -> CubeTensor<R> {
175 let all_steps_one = slices.iter().all(|info| info.step == 1);
177
178 if all_steps_one {
179 let simple_ranges: Vec<Range<usize>> = slices
181 .iter()
182 .enumerate()
183 .map(|(i, slice)| slice.to_range(tensor.meta.shape()[i]))
184 .collect();
185 return slice(tensor, &simple_ranges);
186 }
187
188 let shape_output = tensor.shape().slice(slices).unwrap();
190
191 let output = empty_device_dtype(
193 tensor.client.clone(),
194 tensor.device.clone(),
195 shape_output.clone(),
196 tensor.dtype,
197 );
198
199 let mut starts = SequenceArg::<R, usize>::new();
201 let mut ends = SequenceArg::<R, usize>::new();
202 let mut steps = SequenceArg::<R, i32>::new();
203
204 for (dim, slice) in slices.iter().enumerate() {
205 let range = slice.to_range(tensor.meta.shape()[dim]);
206 starts.push(range.start);
207 ends.push(range.end);
208 steps.push(slice.step as i32);
209 }
210
211 for dim in slices.len()..tensor.meta.num_dims() {
213 starts.push(0);
214 ends.push(tensor.meta.shape[dim]);
215 steps.push(1);
216 }
217
218 let working_units = shape_output.num_elements();
220 let cube_dim = CubeDim::new(&tensor.client, working_units);
221 let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
222 let dtype = tensor.dtype;
223
224 unsafe {
225 slice_with_steps_kernel::launch_unchecked(
226 &output.client,
227 cube_count,
228 cube_dim,
229 address_type!(tensor, output),
230 tensor.into_tensor_arg(),
231 output.clone().into_linear_view(),
232 shape_divmod(&output),
233 starts,
234 ends,
235 steps,
236 dtype.into(),
237 );
238 }
239
240 output
241}
242
/// Extract the compile-time constant value of `value` during kernel expansion.
///
/// Panics (via the inner `unwrap`) if `value` is not a compile-time constant.
/// NOTE(review): marked `#[allow(unused)]` — presumably kept as a helper for
/// kernel development; confirm before removing.
#[allow(unused)]
#[cube]
fn unwrap(value: u32) -> comptime_type!(u32) {
    intrinsic!(|_| value.constant().unwrap().as_u32())
}