use crate::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{
        Layout, LayoutExpand,
        linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view},
    },
};

use super::TensorHandle;
use cubecl::prelude::*;
use cubecl_core::{
    self as cubecl, calculate_cube_count_elemwise, ir::StorageType, tensor_line_size_parallel,
};

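/// Fallback number of streaming multiprocessors, used when the hardware properties of the
/// device do not report one.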
pub const NUM_SM_APPROX: u32 = 50;

#[cube]
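/// Returns the offset into `tensor` (in lines) corresponding to the linear index `offset_layout`
/// expressed in the coordinate system of `layout`, considering only the dimensions in
/// `dim_start..dim_end`.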
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: u32,
    dim_start: u32,
    dim_end: u32,
    #[comptime] unroll: bool,
) -> u32 {
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.line_size()
}

#[cube]
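/// Returns the offset into a strided `tensor` (in lines) for the element at position
/// `offset_layout` of its contiguous representation. Passing the rank as a comptime value lets
/// the loop be unrolled.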
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    #[comptime] rank: Option<u32>,
) -> u32 {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.line_size()
}

#[cube]
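/// Variant of [`index_offset_contiguous`] that takes precomputed [`FastDivmod`] shapes, avoiding
/// hardware division and modulo in the decomposition loop.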
pub fn index_offset_contiguous_fastdivmod(
    offset: u32,
    shape: &Sequence<FastDivmod>,
    stride: &Sequence<u32>,
    #[comptime] line_size: u32,
) -> u32 {
    let rank = comptime![shape.len()];

    let offset_ref = offset * line_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    let mut dim = comptime![rank - 1];

    #[unroll]
    for _ in 0..rank {
        let shape = shape.index(dim);
        let (rem, ogwl) = shape.div_mod(remainder);
        offset += ogwl * stride.index(dim);
        remainder = rem;

        comptime![dim = dim.saturating_sub(1);]
    }

    offset / line_size
}

#[cube(launch)]
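/// Copies `elems_per_thread` lines per unit from the (possibly strided) input view into the
/// contiguous output, staging the lines in registers between the read and the write loop.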
fn into_contiguous_kernel<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: u32,
    #[define(N)] _elem: StorageType,
) {
    let offset_output = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    let mut registers = Array::<Line<N>>::vectorized(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_output + i];
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}

#[cube(launch)]
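/// Like [`into_contiguous_kernel`], but for the case where the input can only be read with a
/// line size of 1 while the output line is wider: each unit gathers `line_size` scalar reads
/// into one output line before storing it.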
fn into_contiguous_kernel_pack<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: u32,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

#[cube]
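/// Reads the `packing` logical values that make up packed position `pos` from a strided tensor
/// whose `packed_dim` stores `packing` values per storage element, and re-packs them into a
/// single output value.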
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: u32,
    in_shape: &Sequence<FastDivmod>,
    #[comptime] packed_dim: u32,
    #[comptime] packing: u32,
    #[comptime] rank: u32,
) -> N {
    let type_size_bits = N::type_size_bits();
    let bits_per_elem = comptime![type_size_bits / packing];
    let mask = comptime![(1u32 << bits_per_elem) - 1];
    let mask = N::cast_from(mask);

    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        #[unroll]
        for i in 0..rank {
            let dim = comptime![rank - i - 1];
            let (rem, mut local_pos) = in_shape.index(dim).div_mod(remainder);
            remainder = rem;
            if comptime![dim == packed_dim] {
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}

#[cube(launch)]
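/// Copies a strided, bit-packed input tensor into a contiguous packed output. Each element of an
/// output line is re-assembled from the input through [`index_packed`].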
fn into_contiguous_kernel_packed<N: Int>(
    input: &Tensor<N>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod>,
    #[comptime] packed_dim: u32,
    #[comptime] packing: u32,
    #[comptime] rank: u32,
    #[comptime] elems_per_thread: u32,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;

            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

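/// Makes a contiguous copy of `input`, allocating a new dense buffer for the result.
///
/// A minimal usage sketch (assumes a `client`, an `input` [`TensorHandle`], and a `dtype` already
/// exist in the caller's scope):
///
/// ```ignore
/// let contiguous = into_contiguous(&client, &input.as_ref(), dtype)?;
/// ```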
pub fn into_contiguous<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<TensorHandle<R>, LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    let handle = client.empty(num_elems * dtype.size());
    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle, dtype);

    into_contiguous_ref(client, input, &output.as_ref(), dtype)?;

    Ok(output)
}

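/// Makes a contiguous copy of `input` while letting the runtime allocate the output, so the
/// result may use a pitched (row-padded) layout where that is more efficient. Tensors of rank
/// <= 1 fall back to [`into_contiguous`].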
pub fn into_contiguous_pitched<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<TensorHandle<R>, LaunchError> {
    if input.shape.len() <= 1 {
        return into_contiguous(client, input, dtype);
    }

    let output = TensorHandle::empty(client, input.shape.to_vec(), dtype);

    into_contiguous_ref(client, input, &output.as_ref(), dtype)?;

    Ok(output)
}

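/// Makes a contiguous copy of a bit-packed tensor. `shape` is the unpacked (logical) shape; the
/// last dimension of the output is divided by `packing` (rounded up) to account for the packed
/// storage elements.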
pub fn into_contiguous_packed<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    shape: &[usize],
    packing: u32,
    dtype: StorageType,
) -> Result<TensorHandle<R>, LaunchError> {
    let rank = shape.len();
    if rank <= 1 {
        return into_contiguous(client, input, dtype);
    }

    let mut out_shape = shape.to_vec();
    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing as usize);
    let output = TensorHandle::empty(client, out_shape, dtype);

    into_contiguous_packed_ref(client, input, &output.as_ref(), shape, packing, dtype)?;

    Ok(output)
}

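/// Copies `input` into the pre-allocated `output`, making it a contiguous version of `input`.
/// Chooses the line size, the number of elements handled per unit, and either the plain or the
/// re-packing kernel based on the shapes and the device properties.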
pub fn into_contiguous_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    let rank = input.strides.len();
    let line_size = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let cube_dim = CubeDim::new(client, num_vecs);
    let simul_vecs = num_sm * cube_dim.num_elems();
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];

    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_vec = if line_size > 1 {
        line_size
    } else {
        *R::supported_line_sizes()
            .iter()
            .filter(|it| num_elems_per_unit.is_multiple_of(**it as u32))
            .max()
            .unwrap_or(&1)
    };

    let input = linear_view(client, input, line_size);
    let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    let launch = if line_size != out_vec && out_vec > 1 {
        into_contiguous_kernel_pack::launch
    } else {
        into_contiguous_kernel::launch
    };

    launch(
        client,
        cube_count,
        cube_dim,
        input,
        output.as_tensor_arg(out_vec),
        out_layout,
        elems_per_unit,
        dtype,
    )
}

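/// Copies a strided, bit-packed `input` into the pre-allocated packed `output`. `shape` is the
/// unpacked (logical) shape and `packing` is the number of logical values stored per storage
/// element along the packed (stride-1) dimension.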
pub fn into_contiguous_packed_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    shape: &[usize],
    packing: u32,
    dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    let rank = input.strides.len();
    let line_size = tensor_line_size_parallel(
        client.io_optimized_line_sizes(&dtype),
        output.shape,
        output.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let cube_dim = CubeDim::new(client, num_vecs);
    let simul_vecs = num_sm * cube_dim.num_elems();
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];
    let packed_dim = input
        .strides
        .iter()
        .enumerate()
        .rev()
        .find(|(_, s)| **s == 1)
        .expect("At least one stride should be 1")
        .0;

    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = LinearLayoutArgs::from_handle(client, output, line_size);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    let in_shape = shape
        .iter()
        .map(|s| FastDivmodArgs::new(client, *s as u32))
        .collect();

    into_contiguous_kernel_packed::launch(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(1),
        output.as_tensor_arg(line_size),
        out_layout,
        in_shape,
        packed_dim as u32,
        packing,
        rank as u32,
        elems_per_unit,
        dtype,
    )
}

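/// Checks whether `strides` describe a fully contiguous row-major layout for `shape`.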
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

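/// Checks whether `strides` describe a row-major layout that is contiguous except for a possible
/// pitch (row padding) between the last two dimensions: the last stride must be 1, strides must
/// be in decreasing order, and all outer dimensions must be tightly packed. Assumes a non-empty
/// shape.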
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank <= 1 {
        return true;
    }

    let mut sorted = strides.to_vec();
    sorted.sort();
    sorted.reverse();

    if sorted != strides {
        return false;
    }

    for i in 0..rank - 2 {
        if strides[i] != shape[i + 1] * strides[i + 1] {
            return false;
        }
    }
    true
}

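/// Computes the compact row-major (C-contiguous) strides for `shape`.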
pub fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    for i in (0..rank - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}