burn_cubecl/kernel/index/
slice_assign.rs1use crate::{
2 CubeRuntime,
3 kernel::utils::{linear_view, shape_divmod},
4 tensor::CubeTensor,
5};
6use cubecl::{
7 calculate_cube_count_elemwise, intrinsic,
8 prelude::*,
9 std::{FastDivmod, FastDivmodArgs, tensor::layout::linear::LinearView},
10};
11
12#[cube(launch_unchecked)]
13fn slice_assign_kernel<E: Numeric>(
14 input: &mut Tensor<Line<E>>,
15 value: &LinearView<Line<E>>,
16 slice_shape: Sequence<FastDivmod<usize>>,
17 slice_offsets: Sequence<usize>,
18 #[define(E)] _dtype: StorageType,
19) {
20 if !value.is_in_bounds(ABSOLUTE_POS) {
21 terminate!()
22 }
23
24 let rank = comptime!(slice_shape.len());
25
26 let line_size = input.line_size();
27 let mut offset_remainder = ABSOLUTE_POS * line_size;
28 let mut offset_input = 0;
29
30 #[allow(clippy::explicit_counter_loop)]
31 #[unroll]
32 for i in 0..rank {
33 let dim = rank - i - 1;
34 let (rem, offset_local) = slice_shape[dim].div_mod(offset_remainder);
35
36 let range_start = slice_offsets[dim];
37 let offset_local_input = offset_local + range_start;
38
39 offset_input += offset_local_input * input.stride(dim);
40 offset_remainder = rem;
41 }
42
43 input[offset_input / line_size] = value[ABSOLUTE_POS];
45}
46
47#[cube(launch_unchecked)]
49fn slice_assign_with_steps_kernel<E: Numeric>(
50 input: &mut Tensor<E>,
51 value: &LinearView<E>,
52 value_shape: Sequence<FastDivmod<usize>>,
53 starts: Sequence<usize>,
54 ends: Sequence<usize>,
55 steps: Sequence<i32>,
56 #[define(E)] _dtype: StorageType,
57) {
58 if !value.is_in_bounds(ABSOLUTE_POS) {
59 terminate!();
60 }
61
62 let rank = comptime![value_shape.len()];
63 let mut value_offset = ABSOLUTE_POS;
64 let mut input_offset = 0;
65
66 #[unroll]
68 for i in 0..rank {
69 let dim = rank - i - 1;
71 let start = starts[dim];
72 let end = ends[dim];
73 let step = steps[dim];
74
75 let (rem, value_idx) = value_shape[dim].div_mod(value_offset);
76 value_offset = rem;
77
78 let input_idx = if step > 0 {
79 start + value_idx * (step as usize)
81 } else if step < 0 {
82 let abs_step = (-step) as usize;
85 let end_minus_1 = end - 1;
86 end_minus_1 - value_idx * abs_step
87 } else {
88 value_idx
90 };
91
92 input_offset += input_idx * input.stride(dim);
93 }
94
95 input[input_offset] = value[ABSOLUTE_POS];
96}
97
98pub(crate) fn slice_assign<R: CubeRuntime>(
99 tensor: CubeTensor<R>,
100 indices: &[burn_backend::Slice],
101 value: CubeTensor<R>,
102) -> CubeTensor<R> {
103 let has_non_unit_step = indices.iter().any(|s| s.step != 1 && s.step != 0);
105
106 if has_non_unit_step {
107 return slice_assign_with_steps(tensor, indices, value);
109 }
110
111 let client = tensor.client.clone();
112 let tensor = match tensor.can_mut() {
113 true => tensor,
114 false => tensor.copy(),
115 };
116 let ndims = tensor.shape.num_dims();
117
118 let line_size = if tensor.strides[ndims - 1] == 1 && value.strides[ndims - 1] == 1 {
119 let last = indices
120 .get(ndims - 1)
121 .cloned()
122 .unwrap_or(burn_backend::Slice {
123 start: 0,
124 end: Some(tensor.shape[ndims - 1] as isize),
125 step: 1,
126 });
127 let end = last.end.unwrap_or(tensor.shape[ndims - 1] as isize);
128 let shape = (end - last.start) as usize;
129 let offset = last.start as usize;
130 *R::supported_line_sizes()
131 .iter()
132 .filter(|it| {
133 let it = **it;
134 shape.is_multiple_of(it)
135 && strides_compatible(&tensor.strides, it)
136 && strides_compatible(&value.strides, it)
137 && offset.is_multiple_of(it)
138 })
139 .max()
140 .unwrap_or(&1)
141 } else {
142 1
143 };
144
145 let mut shape = SequenceArg::<R, FastDivmod<usize>>::new();
146 let mut offsets = SequenceArg::<R, usize>::new();
147
148 for i in 0..ndims {
149 let slice = indices.get(i).cloned().unwrap_or(burn_backend::Slice {
150 start: 0,
151 end: Some(tensor.shape[i] as isize),
152 step: 1,
153 });
154 let start = slice.start as usize;
155 let end = slice.end.unwrap_or(tensor.shape[i] as isize);
156 let length = (end - slice.start) as usize;
157
158 shape.push(FastDivmodArgs::<usize>::new(&client, length));
159 offsets.push(ScalarArg::new(start));
160 }
161
162 let working_units = value.shape.num_elements() / line_size;
163 let cube_dim = CubeDim::new(&tensor.client, working_units);
164 let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
165
166 unsafe {
167 slice_assign_kernel::launch_unchecked(
168 &tensor.client,
169 cube_count,
170 cube_dim,
171 tensor.as_tensor_arg(line_size),
172 linear_view(&value, line_size),
173 shape,
174 offsets,
175 tensor.dtype.into(),
176 )
177 .expect("Kernel to never fail");
178 }
179
180 tensor
181}
182
183pub(crate) fn slice_assign_with_steps<R: CubeRuntime>(
193 tensor: CubeTensor<R>,
194 slices: &[burn_backend::Slice],
195 value: CubeTensor<R>,
196) -> CubeTensor<R> {
197 let tensor = match tensor.can_mut() {
198 true => tensor,
199 false => tensor.copy(),
200 };
201
202 let mut starts = SequenceArg::<R, usize>::new();
204 let mut ends = SequenceArg::<R, usize>::new();
205 let mut steps = SequenceArg::<R, i32>::new();
206
207 for (dim, slice) in slices.iter().enumerate() {
208 let range = slice.to_range(tensor.shape[dim]);
209 starts.push(ScalarArg::new(range.start));
210 ends.push(ScalarArg::new(range.end));
211 steps.push(ScalarArg::new(slice.step as i32));
212 }
213
214 for dim in slices.len()..tensor.shape.num_dims() {
216 starts.push(ScalarArg::new(0));
217 ends.push(ScalarArg::new(tensor.shape[dim]));
218 steps.push(ScalarArg::new(1));
219 }
220
221 let working_units = value.shape.num_elements();
223 let cube_dim = CubeDim::new(&tensor.client, working_units);
224 let cube_count = calculate_cube_count_elemwise(&tensor.client, working_units, cube_dim);
225
226 unsafe {
227 slice_assign_with_steps_kernel::launch_unchecked(
228 &tensor.client,
229 cube_count,
230 cube_dim,
231 tensor.as_tensor_arg(1),
232 linear_view(&value, 1),
233 shape_divmod(&value),
234 starts,
235 ends,
236 steps,
237 tensor.dtype.into(),
238 )
239 .expect("Kernel to never fail");
240 }
241
242 tensor
243}
244
/// Reports whether every entry of `strides` allows access in lines of `vec` elements.
///
/// A stride qualifies when it equals `1` or is a multiple of `vec`. An empty
/// `strides` slice qualifies trivially.
///
/// A `vec` of `0` no longer causes a remainder-by-zero panic: nothing is a
/// multiple of it, so only strides equal to `1` qualify.
fn strides_compatible(strides: &[usize], vec: usize) -> bool {
    strides
        .iter()
        // `checked_rem` yields `None` for `vec == 0` instead of panicking.
        .all(|&stride| stride == 1 || stride.checked_rem(vec) == Some(0))
}
250
251#[allow(unused)]
253#[cube]
254fn unwrap(value: u32) -> comptime_type!(u32) {
255 intrinsic!(|_| value.constant().unwrap().as_u32())
256}