use crate::{
    FastDivmod,
    tensor::{
        TensorHandle, into_contiguous,
        layout::{
            Layout, LayoutExpand,
            linear::{LinearLayout, LinearView, linear_layout, linear_view},
        },
    },
};
use cubecl::prelude::*;
use cubecl_core::{
    self as cubecl, calculate_cube_count_elemwise,
    ir::{StorageType, VectorSize},
    tensor_vector_size_parallel,
    zspace::{Strides, strides},
};

/// Approximate number of streaming multiprocessors, used as a fallback when the hardware
/// properties do not report an exact count.
pub const NUM_SM_APPROX: u32 = 50;

/// Converts the (vectorized) position `offset_layout` within `layout` into the corresponding
/// offset within `tensor`, considering only the dimensions `dim_start..dim_end`. When `unroll`
/// is `true`, the dimension loop is unrolled at compile time.
#[cube]
pub fn index_offset_with_layout<T: Scalar, N1: Size, L: Scalar, N2: Size>(
    tensor: &Tensor<Vector<T, N1>>,
    layout: &Tensor<Vector<L, N2>>,
    offset_layout: usize,
    dim_start: usize,
    dim_end: usize,
    #[comptime] unroll: bool,
) -> usize {
    let offset_ref = offset_layout * tensor.vector_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.vector_size()
}

/// Computes the memory offset into `tensor` for the element at linear (row-major) position
/// `offset_layout`, by decomposing the position into per-dimension coordinates and applying
/// the tensor's strides. When the rank is known at compile time, the loop is unrolled.
#[cube]
pub fn index_offset_contiguous<T: Scalar, N: Size>(
    tensor: &Tensor<Vector<T, N>>,
    offset_layout: usize,
    #[comptime] rank: Option<usize>,
) -> usize {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.vector_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.vector_size()
}

/// Same decomposition as [`index_offset_contiguous`], but driven by precomputed [`FastDivmod`]
/// divisors for the shape and explicit strides, so the per-dimension division and modulo are
/// cheap inside the kernel.
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: usize,
    shape: &Sequence<FastDivmod<usize>>,
    stride: &Sequence<usize>,
    #[comptime] vector_size: VectorSize,
) -> usize {
    let rank = shape.len().comptime();

    let offset_ref = offset * vector_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;

        let (rem, ogwl) = shape[dim].div_mod(remainder);
        offset += ogwl * stride[dim];
        remainder = rem;
    }

    offset / vector_size
}
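
// Worked example for the two helpers above: with shape `[2, 3, 4]` and linear position 17,
// coordinates are peeled off from the innermost dimension outwards: 17 % 4 = 1 (rem 4),
// 4 % 3 = 1 (rem 1), 1 % 2 = 1, i.e. coordinates (1, 1, 1). With the compact strides
// `[12, 4, 1]` this reproduces 17 as expected, while a pitched stride set such as `[24, 8, 1]`
// (rows padded to a pitch of 8) yields 1 * 24 + 1 * 8 + 1 * 1 = 33.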

/// Copy kernel for the already-vectorized case: each unit copies `elems_per_thread`
/// consecutive vectors, staging them in registers between the read and write loops.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel<T: Numeric, N: Size>(
    input: &LinearView<Vector<T, N>>,
    output: &mut Tensor<Vector<T, N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(T)] _elem: StorageType,
) {
    let offset_linear = ABSOLUTE_POS * elems_per_thread;

    let mut registers = Array::<Vector<T, N>>::new(elems_per_thread);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_linear + i];
    }

    let offset_output = out_layout.to_source_pos(offset_linear);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}
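
// Note on the kernel above: the input is read through a `LinearView` built from the strided
// input binding, while the output position is remapped only once through
// `out_layout.to_source_pos`; the remaining writes of the unit are sequential from that offset.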

/// Variant of [`copy_kernel`] for inputs that cannot be vectorized: scalar values are read one
/// by one and packed into vectors in registers before the vectorized write to the output.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel_pack<T: Numeric, N: Size>(
    input: &LinearView<T>,
    output: &mut Tensor<Vector<T, N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(T)] _elem: StorageType,
) {
    let vector_size = output.vector_size().comptime();
    let vectors_per_thread = elems_per_thread / vector_size;

    let offset_output = ABSOLUTE_POS * vectors_per_thread;
    let offset_input = offset_output * vector_size;

    let mut registers = Array::<Vector<T, N>>::new(vectors_per_thread);

    #[unroll]
    for i in 0..vectors_per_thread {
        let offset = i * vector_size;
        let mut reg = Vector::<T, N>::empty();
        #[unroll]
        for k in 0..vector_size {
            let offset_input = offset_input + offset + k;
            reg[k] = input[offset_input];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..vectors_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Gathers one packed storage word for output position `pos`: each of the `packing`
/// sub-elements is located in the (possibly strided) input, extracted from its containing
/// word, and re-packed into a single contiguous word.
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: usize,
    in_shape: &Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
) -> N {
    let type_size_bits = N::type_size_bits().comptime();
    let bits_per_elem = type_size_bits / packing;
    let mask = (1u32 << bits_per_elem) - 1;
    let mask = N::cast_from(mask);

    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        #[unroll]
        for i in 0..rank {
            let dim = rank - i - 1;
            let (rem, mut local_pos) = in_shape[dim].div_mod(remainder);
            remainder = rem;
            if dim == packed_dim {
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}
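
// For instance, with a 32-bit storage word and `packing = 8`, each logical element occupies
// `32 / 8 = 4` bits and `mask = 0xF`. Sub-element `n` of the output word is fetched from its
// source word, shifted down by `packing_offset * 4` bits, masked, and shifted back up by
// `n * 4` bits, so one output word may be assembled from up to 8 different input words.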

/// Copy kernel for packed integer tensors: every output word is reassembled with
/// [`index_packed`] before the vectorized write. Units past the end of the output exit early.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel_packed<T: Int, N: Size>(
    input: &Tensor<T>,
    output: &mut Tensor<Vector<T, N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
    #[comptime] elems_per_thread: usize,
    #[define(T)] _elem: StorageType,
) {
    let vector_size = output.vector_size().comptime();
    let vectors_per_thread = elems_per_thread / vector_size;

    let offset_output = ABSOLUTE_POS * vectors_per_thread;
    let offset_input = offset_output * vector_size;

    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Vector<T, N>>::new(vectors_per_thread);

    #[unroll]
    for i in 0..vectors_per_thread {
        let offset = i * vector_size;
        let mut reg = Vector::<T, N>::empty();
        #[unroll]
        for k in 0..vector_size {
            let offset_input = offset_input + offset + k;

            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..vectors_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Makes a packed tensor contiguous. `shape` is the unpacked (logical) shape; the output's
/// last dimension is shrunk to `shape[rank - 1].div_ceil(packing)` storage words. Rank 0 and
/// rank 1 tensors fall back to [`into_contiguous`].
pub fn into_contiguous_packed<R: Runtime>(
    client: &ComputeClient<R>,
    input: TensorBinding<R>,
    packed_dim: usize,
    shape: &[usize],
    packing: usize,
    dtype: StorageType,
) -> TensorHandle<R> {
    let rank = shape.len();
    if rank <= 1 {
        return into_contiguous(client, input, dtype);
    }

    let mut out_shape = shape.to_vec();
    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing);
    let output = TensorHandle::empty(client, out_shape, dtype);

    into_contiguous_packed_ref(
        client,
        input,
        output.clone().binding(),
        packed_dim,
        shape,
        packing,
        dtype,
    );

    output
}
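
// As an illustration, with 8 logical values packed per storage word, a tensor of logical
// shape `[4, 10]` produces a contiguous output of shape `[4, 2]`, since
// `10usize.div_ceil(8) == 2`; the last word of each row is only partially used.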

/// Copies `input` into `output` on the device, choosing a vectorization and per-unit workload
/// from the shapes, strides, and hardware properties of the client.
pub fn copy_gpu_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: TensorBinding<R>,
    output: TensorBinding<R>,
    dtype: StorageType,
) {
    let num_elems: usize = input.shape.iter().product();

    let in_rank = input.strides.len();
    let out_rank = output.strides.len();
    let vector_size_in = tensor_vector_size_parallel(
        client.io_optimized_vector_sizes(dtype.size()),
        &input.shape,
        &input.strides,
        in_rank - 1,
    );
    let vector_size_out = tensor_vector_size_parallel(
        client.io_optimized_vector_sizes(dtype.size()),
        &output.shape,
        &output.strides,
        out_rank - 1,
    );
    let vector_size = vector_size_in.min(vector_size_out);

    let num_vecs = num_elems / vector_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let cube_dim = CubeDim::new(client, num_vecs);
    let simul_vecs = num_sm * cube_dim.num_elems();
    // Heuristic: the more vectors each wave of units has to cover, the more elements each
    // unit processes, capped at 8.
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = vector_size as usize * elems_per_unit;

    let last_dim = output.shape[out_rank - 1];

    // Shrink the per-unit workload until it evenly divides the innermost output dimension,
    // so no unit straddles a row boundary.
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_vec = if vector_size > 1 {
        vector_size
    } else {
        client
            .io_optimized_vector_sizes(dtype.size())
            .filter(|it| num_elems_per_unit.is_multiple_of(*it))
            .max()
            .unwrap_or(1)
    };

    let address_type = input
        .required_address_type(dtype.size())
        .max(output.required_address_type(dtype.size()));
    let input = linear_view(input);
    let out_layout = linear_layout(&output, out_vec);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    // When only the output can be vectorized, use the pack kernel to assemble vectors from
    // scalar reads; otherwise use the plain vectorized copy.
    let launch = if vector_size != out_vec && out_vec > 1 {
        copy_kernel_pack::launch
    } else {
        copy_kernel::launch
    };

    launch(
        client,
        cube_count,
        cube_dim,
        address_type,
        out_vec,
        input,
        output.into_tensor_arg(),
        out_layout,
        elems_per_unit,
        dtype,
    )
}

/// Variant of [`into_contiguous_packed`] that writes into an existing `output` binding.
pub fn into_contiguous_packed_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: TensorBinding<R>,
    output: TensorBinding<R>,
    packed_dim: usize,
    shape: &[usize],
    packing: usize,
    dtype: StorageType,
) {
    let num_elems: usize = input.shape.iter().product();

    let in_rank = input.strides.len();
    let out_rank = output.strides.len();
    // `packed_dim` counts from the innermost dimension; convert it to an absolute axis index
    // for the kernel.
    let in_packed_dim = in_rank - packed_dim - 1;
    let vector_size = tensor_vector_size_parallel(
        client.io_optimized_vector_sizes(dtype.size()),
        &output.shape,
        &output.strides,
        out_rank - 1,
    );
    let num_vecs = num_elems / vector_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);

    let cube_dim = CubeDim::new(client, num_vecs);
    let simul_vecs = num_sm * cube_dim.num_elems();
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = vector_size as usize * elems_per_unit;

    let last_dim = output.shape[out_rank - 1];

    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = linear_layout(&output, vector_size);

    let address_type = input
        .required_address_type(dtype.size())
        .max(output.required_address_type(dtype.size()));
    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    let in_shape = shape.iter().copied().collect();

    copy_kernel_packed::launch(
        client,
        cube_count,
        cube_dim,
        address_type,
        vector_size,
        input.into_tensor_arg(),
        output.into_tensor_arg(),
        out_layout,
        in_shape,
        in_packed_dim,
        packing,
        in_rank,
        elems_per_unit,
        dtype,
    )
}

/// Checks whether `strides` describe a fully contiguous, row-major layout for `shape`.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (&expected, &stride) in compact_strides(shape).iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

/// Checks whether `strides` describe a row-major layout that is contiguous except for a
/// possible pitch (padding) of the innermost dimension, as produced by pitched allocations.
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    // Match `is_contiguous`: an empty shape is trivially contiguous.
    if rank == 0 {
        return true;
    }
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank <= 1 {
        return true;
    }

    let mut sorted = strides.to_vec();
    sorted.sort();
    sorted.reverse();

    if sorted != strides {
        return false;
    }

    // All dimensions except the second-to-last must be tightly packed; the stride of
    // dimension `rank - 2` is allowed to exceed `shape[rank - 1]` to account for the row pitch.
    for i in 0..rank - 2 {
        if strides[i] != shape[i + 1] * strides[i + 1] {
            return false;
        }
    }
    true
}

/// Computes the compact (row-major) strides for `shape`.
pub fn compact_strides(shape: &[usize]) -> Strides {
    let rank = shape.len();
    let mut strides = strides![1; rank];
    // `saturating_sub` keeps the range empty for rank 0 instead of underflowing.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
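
// A few host-side sanity checks for the layout predicates above; the shapes and strides are
// illustrative values only.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn contiguity_checks() {
        // Row-major strides for shape [2, 3, 4] are [12, 4, 1].
        assert!(is_contiguous(&[2, 3, 4], &[12, 4, 1]));
        // A padded (pitched) innermost dimension is not fully contiguous...
        assert!(!is_contiguous(&[2, 3, 4], &[24, 8, 1]));
        // ...but it is accepted by the pitched check, since only the row pitch differs.
        assert!(is_contiguous_pitched(&[2, 3, 4], &[24, 8, 1]));
        // A permuted layout fails the pitched check as well.
        assert!(!is_contiguous_pitched(&[2, 3, 4], &[1, 2, 6]));
    }
}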