use crate::{FastDivmod, FastDivmodArgs};

use super::TensorHandle;
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel};

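/// Fallback estimate of the number of streaming multiprocessors, used when the hardware
/// properties don't report an exact count.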
pub const NUM_SM_APPROX: u32 = 50;

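/// Returns the line index into `tensor` for the element found at line index `offset_layout` in
/// the memory layout described by `layout`, considering only the dimensions in
/// `dim_start..dim_end`. The loop is unrolled when `unroll` is set at compile time.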
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: u32,
    dim_start: u32,
    dim_end: u32,
    #[comptime] unroll: bool,
) -> u32 {
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.line_size()
}

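/// Returns the line index into the strided `tensor` for the element found at line index
/// `offset_layout` in a contiguous layout of the same shape. Passing `rank` as a comptime value
/// allows the loop over dimensions to be unrolled.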
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    #[comptime] rank: Option<u32>,
) -> u32 {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.line_size()
}

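/// Same as [`index_offset_contiguous`], but takes the shape as precomputed [`FastDivmod`] values
/// and explicit strides, so the per-dimension division and modulo are cheap.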
#[cube]
pub fn index_offset_contiguous_fastdivmod<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    shape: &Sequence<FastDivmod>,
    stride: &Sequence<u32>,
) -> u32 {
    let rank = comptime![shape.len()];

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    let mut dim = comptime![rank - 1];

    #[unroll]
    for _ in 0..rank {
        let shape = shape.index(dim);
        let (rem, ogwl) = shape.div_mod(remainder);
        offset += ogwl * stride.index(dim);
        remainder = rem;

        comptime![dim = dim.saturating_sub(1);]
    }

    offset / tensor.line_size()
}

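/// Memory layout of the kernel output.
///
/// `Pitched` marks an output whose rows are padded along the last dimension; the contained
/// [`FastDivmod`] holds the logical width of that dimension. `None` means the output is fully
/// contiguous.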
#[derive(CubeType, CubeLaunch)]
pub enum StridedLayout {
    Pitched(FastDivmod),
    None,
}

impl<R: Runtime> StridedLayoutArgs<'_, R> {
    pub fn none() -> Self {
        Self::None
    }

    pub fn strided(client: &ComputeClient<R::Server, R::Channel>, shape: u32) -> Self {
        Self::Pitched(FastDivmodArgs::new(client, shape))
    }
}

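/// Index resolution for [`StridedLayout`]: translates a contiguous output index into a physical
/// one, accounting for the row pitch when the layout is `Pitched`.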
#[cube]
impl StridedLayout {
    pub fn index<T: CubePrimitive>(&self, tensor: &Tensor<Line<T>>, index: u32) -> u32 {
        match self {
            StridedLayout::Pitched(divmod) => {
                let offset_abs = index * tensor.line_size();
                let (y, x) = divmod.div_mod(offset_abs);
                let offset = y * tensor.stride(tensor.rank() - 2) + x;
                offset / tensor.line_size()
            }
            StridedLayout::None => index,
        }
    }
}

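/// Copy kernel: each unit gathers `elems_per_thread` lines from the strided `input` into
/// registers and writes them back linearly to `output`.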
#[cube(launch)]
fn into_contiguous_kernel<N: CubePrimitive>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: StridedLayout,
    shape: Sequence<FastDivmod>,
    stride: Sequence<u32>,
    #[comptime] elems_per_thread: u32,
) {
    let offset_output = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    let mut registers = Array::<Line<N>>::vectorized(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        let offset_input =
            index_offset_contiguous_fastdivmod::<N>(input, offset_output + i, &shape, &stride);
        registers[i] = input[offset_input];
    }

    let offset_output = out_layout.index(output, offset_output);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}

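/// Variant of [`into_contiguous_kernel`] for inputs that can't be read as lines: elements are
/// gathered one by one and packed into full output lines before being written back.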
#[cube(launch)]
fn into_contiguous_kernel_pack<N: CubePrimitive>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: StridedLayout,
    shape: Sequence<FastDivmod>,
    stride: Sequence<u32>,
    #[comptime] elems_per_thread: u32,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            let offset_input =
                index_offset_contiguous_fastdivmod::<N>(input, offset_input, &shape, &stride);
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.index(output, offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

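/// Makes `input` contiguous, returning a newly allocated tensor with compact strides.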
pub fn into_contiguous<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    let num_elems: usize = input.shape.iter().product();

    let handle = client.empty(num_elems * size_of::<E>());
    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle);

    into_contiguous_ref::<R, E>(client, input, &output.as_ref());

    output
}

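/// Makes `input` contiguous, letting the runtime allocate the output, which may be pitched
/// (row-padded) depending on the backend. Tensors of rank 1 or lower fall back to
/// [`into_contiguous`].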
pub fn into_contiguous_pitched<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    if input.shape.len() <= 1 {
        return into_contiguous(client, input);
    }

    let output = TensorHandle::empty(client, input.shape.to_vec());

    into_contiguous_ref::<R, E>(client, input, &output.as_ref());

    output
}

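/// Copies `input` into the pre-allocated `output` so that it becomes contiguous (up to row
/// pitch), choosing a line size and the number of elements per unit from the tensor shape and
/// the available hardware.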
pub fn into_contiguous_ref<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
) {
    let num_elems: usize = input.shape.iter().product();

    let rank = input.strides.len();
    // Pick the widest supported line size that is compatible with the input strides along the
    // last dimension.
    let vectorization_factor = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );
    let num_vecs = num_elems / vectorization_factor as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let simul_vecs = num_sm * CubeDim::default().num_elems();
    // Process more elements per unit when there are many more vectors than can run concurrently.
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = vectorization_factor as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];
    let is_padded = rank > 1 && last_dim != output.strides[rank - 2];

    // When the output is pitched, a unit's writes must not cross a row boundary, so the number of
    // elements handled per unit has to divide the last dimension evenly.
    while is_padded && last_dim % num_elems_per_unit as usize != 0 {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = match is_padded {
        true => StridedLayoutArgs::strided(client, last_dim as u32),
        false => StridedLayoutArgs::none(),
    };

    // Even when the input can't be vectorized, the output writes can still use the largest
    // supported line size that divides the number of elements per unit.
    let out_vec = if vectorization_factor > 1 {
        vectorization_factor
    } else {
        *R::supported_line_sizes()
            .iter()
            .filter(|it| num_elems_per_unit % **it as u32 == 0)
            .max()
            .unwrap_or(&1)
    };

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    let shape = SequenceArg {
        values: input
            .shape
            .iter()
            .map(|dim| FastDivmodArgs::new(client, *dim as u32))
            .collect(),
    };

    let stride = SequenceArg {
        values: input
            .strides
            .iter()
            .map(|s| ScalarArg::new(*s as u32))
            .collect(),
    };

    // Use the packing kernel when only the output is vectorized, so scalar reads can still be
    // written back as full lines.
    let launch = if vectorization_factor != out_vec && out_vec > 1 {
        into_contiguous_kernel_pack::launch::<E, R>
    } else {
        into_contiguous_kernel::launch::<E, R>
    };

    launch(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(vectorization_factor),
        output.as_tensor_arg(out_vec),
        out_layout,
        shape,
        stride,
        elems_per_unit,
    );
}

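/// Returns `true` if `strides` matches the compact, row-major strides of `shape`.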
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

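/// Computes the compact, row-major strides for `shape`, with the last dimension contiguous.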
pub fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps the range empty for rank 0, which would otherwise underflow.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}