use crate::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{
        Layout, LayoutExpand,
        linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view},
    },
};

use super::TensorHandle;
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel};

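/// Fallback number of streaming multiprocessors, used to size launches when the hardware
/// properties do not report an actual count.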
pub const NUM_SM_APPROX: u32 = 50;

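/// Returns the line index in `tensor` that corresponds to the linear position
/// `offset_layout`, interpreted through the strides of `layout`. Only the dimensions in
/// `dim_start..dim_end` contribute to the offset.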
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: u32,
    dim_start: u32,
    dim_end: u32,
    #[comptime] unroll: bool,
) -> u32 {
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.line_size()
}

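/// Returns the line index in `tensor` for the element at contiguous (row-major) position
/// `offset_layout`. When `rank` is provided at compile time, the loop over dimensions is
/// unrolled.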
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    #[comptime] rank: Option<u32>,
) -> u32 {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.line_size()
}

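/// Same as [`index_offset_contiguous`], but takes the shape as precomputed [`FastDivmod`]
/// values and the strides as an explicit sequence, avoiding runtime integer division.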
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: u32,
    shape: &Sequence<FastDivmod>,
    stride: &Sequence<u32>,
    #[comptime] line_size: u32,
) -> u32 {
    let rank = comptime![shape.len()];

    let offset_ref = offset * line_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    let mut dim = comptime![rank - 1];

    #[unroll]
    for _ in 0..rank {
        let shape = shape.index(dim);
        let (rem, ogwl) = shape.div_mod(remainder);
        offset += ogwl * stride.index(dim);
        remainder = rem;

        comptime![dim = dim.saturating_sub(1);]
    }

    offset / line_size
}

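/// Copies `elems_per_thread` lines from the (possibly strided) `input` view into the
/// contiguous `output`, staging the values in registers between the reads and the writes.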
#[cube(launch)]
fn into_contiguous_kernel<N: CubePrimitive>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: u32,
) {
    let offset_output = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    let mut registers = Array::<Line<N>>::vectorized(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_output + i];
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}

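/// Variant of [`into_contiguous_kernel`] for inputs that can only be read with a line size
/// of 1: each unit gathers `line_size` scalar reads into a wider line before writing it to
/// the vectorized output.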
#[cube(launch)]
fn into_contiguous_kernel_pack<N: CubePrimitive>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: u32,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

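/// Reads the packed integer at logical position `pos`, gathering its `packing` sub-elements
/// from their (possibly strided) locations in `tensor` and reassembling them into a single
/// densely packed value.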
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: u32,
    in_shape: &Sequence<FastDivmod>,
    #[comptime] packed_dim: u32,
    #[comptime] packing: u32,
    #[comptime] rank: u32,
) -> N {
    let bits_per_elem = comptime![N::elem_size_bits() / packing];
    let mask = comptime![(1u32 << bits_per_elem) - 1];
    let mask = N::cast_from(mask);

    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        #[unroll]
        for i in 0..rank {
            let dim = comptime![rank - i - 1];
            let (rem, mut local_pos) = in_shape.index(dim).div_mod(remainder);
            remainder = rem;
            if comptime![dim == packed_dim] {
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}

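/// Makes a bit-packed integer tensor contiguous: each output element is reassembled from
/// `packing` logical sub-elements fetched through [`index_packed`], so the result stays
/// densely packed along the last dimension.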
#[cube(launch)]
fn into_contiguous_kernel_packed<N: Int>(
    input: &Tensor<N>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod>,
    #[comptime] packed_dim: u32,
    #[comptime] packing: u32,
    #[comptime] rank: u32,
    #[comptime] elems_per_thread: u32,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;

            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

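/// Makes `input` contiguous, returning a new tensor with compact (row-major) strides.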
pub fn into_contiguous<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    let num_elems: usize = input.shape.iter().product();

    let handle = client.empty(num_elems * size_of::<E>());
    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle);

    into_contiguous_ref::<R, E>(client, input, &output.as_ref());

    output
}

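/// Makes `input` contiguous, but allocates the output with `TensorHandle::empty`, so its
/// rows may be padded (pitched) for alignment. Tensors of rank <= 1 fall back to
/// [`into_contiguous`].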
pub fn into_contiguous_pitched<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    if input.shape.len() <= 1 {
        return into_contiguous(client, input);
    }

    let output = TensorHandle::empty(client, input.shape.to_vec());

    into_contiguous_ref::<R, E>(client, input, &output.as_ref());

    output
}

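/// Makes a bit-packed integer tensor contiguous. `shape` is the unpacked (logical) shape and
/// `packing` is the number of sub-elements stored per integer; the output's last dimension is
/// `shape[rank - 1].div_ceil(packing)`.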
pub fn into_contiguous_packed<R: Runtime, I: Int>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    shape: &[usize],
    packing: u32,
) -> TensorHandle<R, I> {
    let rank = shape.len();
    if rank <= 1 {
        return into_contiguous(client, input);
    }

    let mut out_shape = shape.to_vec();
    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing as usize);
    let output = TensorHandle::<R, I>::empty(client, out_shape);

    into_contiguous_packed_ref::<R, I>(client, input, &output.as_ref(), shape, packing);

    output
}

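/// Copies `input` into the pre-allocated `output` so that it becomes contiguous. The line
/// size, the number of elements handled per unit and the kernel variant are chosen from the
/// shapes, strides and reported hardware properties.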
pub fn into_contiguous_ref<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
) {
    let num_elems: usize = input.shape.iter().product();

    let rank = input.strides.len();
    // Vectorize along the last dimension of the input when its layout allows it.
    let line_size = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let simul_vecs = num_sm * CubeDim::default().num_elems();
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];
    let is_padded = rank > 1 && last_dim != output.strides[rank - 2];

    // When the output rows are padded (pitched), a unit must not process a group of elements
    // that crosses the padding boundary, so shrink the group until it divides the last dim.
    while is_padded && !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_vec = if line_size > 1 {
        line_size
    } else {
        *R::supported_line_sizes()
            .iter()
            .filter(|it| num_elems_per_unit.is_multiple_of(**it as u32))
            .max()
            .unwrap_or(&1)
    };

    let input = linear_view(client, input, line_size);
    let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec);

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    // If the input can only be read scalar but the output can be written vectorized, use the
    // packing kernel to assemble wider lines before writing.
    let launch = if line_size != out_vec && out_vec > 1 {
        into_contiguous_kernel_pack::launch::<E, R>
    } else {
        into_contiguous_kernel::launch::<E, R>
    };

    launch(
        client,
        cube_count,
        cube_dim,
        input,
        output.as_tensor_arg(out_vec),
        out_layout,
        elems_per_unit,
    );
}

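/// Packed counterpart of [`into_contiguous_ref`]: copies the bit-packed `input` into the
/// pre-allocated `output`, re-packing along the input dimension whose stride is 1.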
pub fn into_contiguous_packed_ref<R: Runtime, E: Int>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    shape: &[usize],
    packing: u32,
) {
    let num_elems: usize = input.shape.iter().product();

    let rank = input.strides.len();
    // Vectorize on the output, which is contiguous along its last dimension.
    let line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes(&E::as_type_native_unchecked()),
        output.shape,
        output.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let simul_vecs = num_sm * CubeDim::default().num_elems();
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];
    // The packed dimension is the one with stride 1, searched from the last dimension.
    let packed_dim = input
        .strides
        .iter()
        .enumerate()
        .rev()
        .find(|(_, s)| **s == 1)
        .expect("At least one stride should be 1")
        .0;

    // Shrink the per-unit workload until it evenly divides the last dimension of the output.
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = LinearLayoutArgs::from_handle(client, output, line_size);

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    let in_shape = shape
        .iter()
        .map(|s| FastDivmodArgs::new(client, *s as u32))
        .collect();

    into_contiguous_kernel_packed::launch::<E, R>(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(1),
        output.as_tensor_arg(line_size),
        out_layout,
        in_shape,
        packed_dim as u32,
        packing,
        rank as u32,
        elems_per_unit,
    );
}

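/// Returns `true` if `strides` match the compact row-major strides of `shape`.
///
/// For example, `is_contiguous(&[2, 3], &[3, 1])` is `true`, while
/// `is_contiguous(&[2, 3], &[1, 2])` is `false`.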
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

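/// Returns `true` if the tensor is contiguous except for possible row padding (pitch): the
/// last stride must be 1, strides must be non-increasing, and every stride except
/// `strides[rank - 2]` must equal the product of the next shape and stride.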
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank <= 1 {
        return true;
    }

    let mut sorted = strides.to_vec();
    sorted.sort();
    sorted.reverse();

    if sorted != strides {
        return false;
    }

    for i in 0..rank - 2 {
        if strides[i] != shape[i + 1] * strides[i + 1] {
            return false;
        }
    }
    true
}

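/// Computes the compact (row-major) strides for `shape`. Assumes a non-empty shape.
///
/// For example, `compact_strides(&[2, 3, 4])` returns `[12, 4, 1]`.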
pub fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    for i in (0..rank - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}