use crate::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{
        Layout, LayoutExpand,
        linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view},
    },
};

use super::TensorHandle;
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel};

pub const NUM_SM_APPROX: u32 = 50;

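/// Computes the line index into `tensor` that corresponds to the line index `offset_layout`
/// in the reference `layout` tensor, considering only the dimensions `dim_start..dim_end`.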
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: u32,
    dim_start: u32,
    dim_end: u32,
    #[comptime] unroll: bool,
) -> u32 {
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.line_size()
}

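/// Computes the line index into a (possibly strided) `tensor` that corresponds to the line
/// index `offset_layout` in a contiguous tensor of the same shape. Passing a comptime `rank`
/// allows the loop over dimensions to be unrolled.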
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    #[comptime] rank: Option<u32>,
) -> u32 {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.line_size()
}

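/// Same mapping as [`index_offset_contiguous`], but with the shape provided as precomputed
/// [`FastDivmod`] divisors and the strides as an explicit sequence, so the loop over
/// dimensions is fully unrolled.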
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: u32,
    shape: &Sequence<FastDivmod>,
    stride: &Sequence<u32>,
    #[comptime] line_size: u32,
) -> u32 {
    let rank = comptime![shape.len()];

    let offset_ref = offset * line_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    let mut dim = comptime![rank - 1];

    #[unroll]
    for _ in 0..rank {
        let shape = shape.index(dim);
        let (rem, ogwl) = shape.div_mod(remainder);
        offset += ogwl * stride.index(dim);
        remainder = rem;

        comptime![dim = dim.saturating_sub(1);]
    }

    offset / line_size
}

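/// Copy kernel: each unit reads `elems_per_thread` lines from the strided `input` view into
/// registers, then writes them to `output` at the position given by `out_layout`.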
#[cube(launch)]
fn into_contiguous_kernel<N: CubePrimitive>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: u32,
) {
    let offset_output = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    let mut registers = Array::<Line<N>>::vectorized(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_output + i];
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}

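/// Variant of [`into_contiguous_kernel`] used when the input can only be read with line size 1
/// but the output can be written with a wider line size: scalar reads are packed into lines
/// before being stored.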
#[cube(launch)]
fn into_contiguous_kernel_pack<N: CubePrimitive>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: u32,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

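/// Reads the packed value at linear position `pos`: gathers `packing` logical elements of
/// `elem_size_bits / packing` bits each from the (possibly strided) `tensor`, whose
/// `packed_dim` stores `packing` values per storage element, and repacks them into a single
/// output value.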
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: u32,
    in_shape: &Sequence<FastDivmod>,
    #[comptime] packed_dim: u32,
    #[comptime] packing: u32,
    #[comptime] rank: u32,
) -> N {
    let bits_per_elem = comptime![N::elem_size_bits() / packing];
    let mask = comptime![(1u32 << bits_per_elem) - 1];
    let mask = N::cast_from(mask);

    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        #[unroll]
        for i in 0..rank {
            let dim = comptime![rank - i - 1];
            let (rem, mut local_pos) = in_shape.index(dim).div_mod(remainder);
            remainder = rem;
            if comptime![dim == packed_dim] {
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}

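/// Copy kernel for packed integer tensors: values along `packed_dim` are re-packed with
/// [`index_packed`] so that the contiguous output stays densely packed.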
#[cube(launch)]
fn into_contiguous_kernel_packed<N: Int>(
    input: &Tensor<N>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod>,
    #[comptime] packed_dim: u32,
    #[comptime] packing: u32,
    #[comptime] rank: u32,
    #[comptime] elems_per_thread: u32,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;

            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

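/// Makes `input` contiguous, returning a new tensor with compact (row-major) strides.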
pub fn into_contiguous<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    let num_elems: usize = input.shape.iter().product();

    let handle = client.empty(num_elems * size_of::<E>());
    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle);

    into_contiguous_ref::<R, E>(client, input, &output.as_ref());

    output
}

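/// Makes `input` contiguous, allowing the output rows to be padded (pitched) when the runtime
/// allocates the output that way for alignment.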
pub fn into_contiguous_pitched<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    if input.shape.len() <= 1 {
        return into_contiguous(client, input);
    }

    let output = TensorHandle::empty(client, input.shape.to_vec());

    into_contiguous_ref::<R, E>(client, input, &output.as_ref());

    output
}

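/// Makes a packed integer tensor contiguous. `shape` is the unpacked (logical) shape and
/// `packing` is the number of logical values stored per element; the last dimension of the
/// output is shrunk to `shape[rank - 1].div_ceil(packing)`.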
pub fn into_contiguous_packed<R: Runtime, I: Int>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    shape: &[usize],
    packing: u32,
) -> TensorHandle<R, I> {
    let rank = shape.len();
    if rank <= 1 {
        return into_contiguous(client, input);
    }

    let mut out_shape = shape.to_vec();
    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing as usize);
    let output = TensorHandle::<R, I>::empty(client, out_shape);

    into_contiguous_packed_ref::<R, I>(client, input, &output.as_ref(), shape, packing);

    output
}

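/// Makes `input` contiguous by launching the copy kernel into the pre-allocated `output`.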
pub fn into_contiguous_ref<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
) {
    let num_elems: usize = input.shape.iter().product();

    let rank = input.strides.len();
    let line_size = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let simul_vecs = num_sm * CubeDim::default().num_elems();
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];

    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_vec = if line_size > 1 {
        line_size
    } else {
        *R::supported_line_sizes()
            .iter()
            .filter(|it| num_elems_per_unit.is_multiple_of(**it as u32))
            .max()
            .unwrap_or(&1)
    };

    let input = linear_view(client, input, line_size);
    let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec);

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    let launch = if line_size != out_vec && out_vec > 1 {
        into_contiguous_kernel_pack::launch::<E, R>
    } else {
        into_contiguous_kernel::launch::<E, R>
    };

    launch(
        client,
        cube_count,
        cube_dim,
        input,
        output.as_tensor_arg(out_vec),
        out_layout,
        elems_per_unit,
    );
}

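/// Packed variant of [`into_contiguous_ref`]: re-packs `input` along its unit-stride dimension
/// and writes the result into the pre-allocated `output`.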
pub fn into_contiguous_packed_ref<R: Runtime, E: Int>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    shape: &[usize],
    packing: u32,
) {
    let num_elems: usize = input.shape.iter().product();

    let rank = input.strides.len();
    let line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes(&E::as_type_native_unchecked()),
        output.shape,
        output.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let simul_vecs = num_sm * CubeDim::default().num_elems();
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];
    let packed_dim = input
        .strides
        .iter()
        .enumerate()
        .rev()
        .find(|(_, s)| **s == 1)
        .expect("At least one stride should be 1")
        .0;

    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = LinearLayoutArgs::from_handle(client, output, line_size);

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    let in_shape = shape
        .iter()
        .map(|s| FastDivmodArgs::new(client, *s as u32))
        .collect();

    into_contiguous_kernel_packed::launch::<E, R>(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(1),
        output.as_tensor_arg(line_size),
        out_layout,
        in_shape,
        packed_dim as u32,
        packing,
        rank as u32,
        elems_per_unit,
    );
}

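/// Returns `true` if `strides` match the compact row-major strides of `shape`.
///
/// For example, shape `[2, 3]` with strides `[3, 1]` is contiguous, while strides `[4, 1]`
/// (a pitched row stride) are not.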
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

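/// Returns `true` if the tensor is contiguous up to a pitched (padded) row stride: the
/// innermost dimension has stride 1, strides are non-increasing towards the innermost
/// dimension, and every dimension above the rows is tightly packed against the one below it,
/// so only the row stride (`strides[rank - 2]`) may contain padding.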
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank <= 1 {
        return true;
    }

    let mut sorted = strides.to_vec();
    sorted.sort();
    sorted.reverse();

    if sorted != strides {
        return false;
    }

    for i in 0..rank - 2 {
        if strides[i] != shape[i + 1] * strides[i + 1] {
            return false;
        }
    }
    true
}

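/// Computes compact (row-major) strides for `shape`.
///
/// For example, `compact_strides(&[2, 3, 4])` returns `[12, 4, 1]`.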
pub fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    for i in (0..rank - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}