1use crate::{
2 FastDivmod, FastDivmodArgs,
3 tensor::{
4 TensorHandle, into_contiguous_ref,
5 layout::{
6 Layout, LayoutExpand,
7 linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view},
8 },
9 },
10};
11use cubecl::prelude::*;
12use cubecl_core::{
13 self as cubecl, calculate_cube_count_elemwise,
14 ir::{LineSize, StorageType},
15 tensor_line_size_parallel,
16 zspace::{Strides, strides},
17};
18
/// Fallback estimate for the number of streaming multiprocessors, used to
/// size per-thread workloads when the hardware properties don't report a
/// real SM count (see the `.unwrap_or(NUM_SM_APPROX)` call sites below).
pub const NUM_SM_APPROX: u32 = 50;
20
/// Returns the line index into `tensor` corresponding to the linear line
/// position `offset_layout` in the (possibly differently strided) `layout`
/// tensor, considering only dimensions `dim_start..dim_end`.
///
/// Works in element space: the line offset is scaled up by the line size,
/// translated through shapes/strides, then scaled back down to lines.
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: usize,
    dim_start: usize,
    dim_end: usize,
    #[comptime] unroll: bool,
) -> usize {
    // Convert the line offset into an element offset.
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        // Coordinate along dim `i` in the layout's indexing space.
        let ogwl = offset_ref / layout.stride(i);
        // Wrap by the target shape (broadcast-friendly) and project onto the
        // target tensor's strides.
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    // Back to line space.
    offset / tensor.line_size()
}
42
/// Translates a linear (contiguous) line offset into the line offset of the
/// same logical element within a strided `tensor`.
///
/// Decomposes the element position digit-by-digit from the innermost
/// dimension outward and re-projects each coordinate onto the tensor's
/// strides. When `rank` is supplied at comptime the loop is unrolled.
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: usize,
    #[comptime] rank: Option<usize>,
) -> usize {
    // Unroll only when the rank is known at compile time.
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    // Work in element space rather than line space.
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        // Iterate innermost -> outermost.
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    // Back to line space.
    offset / tensor.line_size()
}
68
/// Variant of [`index_offset_contiguous`] taking explicit shape/stride
/// sequences, with shapes precomputed as [`FastDivmod`] so the per-dimension
/// division and modulo avoid hardware division.
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: usize,
    shape: &Sequence<FastDivmod<usize>>,
    stride: &Sequence<usize>,
    #[comptime] line_size: LineSize,
) -> usize {
    let rank = shape.len().comptime();

    // Convert the incoming line offset into an element offset.
    let offset_ref = offset * line_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll]
    for i in 0..rank {
        // Iterate innermost -> outermost.
        let dim = rank - i - 1;

        // One fused divide/modulo per dimension.
        let (rem, ogwl) = shape[dim].div_mod(remainder);
        offset += ogwl * stride[dim];
        remainder = rem;
    }

    // Back to line space.
    offset / line_size
}
94
/// Straight copy: each thread moves `elems_per_thread` consecutive lines from
/// the linear `input` view into `output`, remapping the destination position
/// through `out_layout`.
///
/// NOTE(review): unlike `copy_kernel_packed` there is no out-of-bounds guard
/// here — presumably the launch configuration divides the tensor exactly
/// (see the divisibility loop in `copy_gpu_ref`); confirm.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    // First line handled by this thread.
    let offset_linear = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    // Stage all loads in registers before storing, so loads and stores each
    // unroll into a tight sequence.
    let mut registers = Array::<Line<N>>::lined(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_linear + i];
    }

    // Remap the linear position into the (possibly strided) output.
    let offset_output = out_layout.to_source_pos(offset_linear);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}
120
/// Copy variant used when the input could not be vectorized but the output
/// can: gathers `line_size` scalar input reads into each output line.
///
/// `elems_per_thread` counts scalar elements; each thread assembles and
/// writes `elems_per_thread / line_size` output lines.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel_pack<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size().comptime();
    let lines_per_thread = elems_per_thread / line_size;

    // Output offset is in lines; input offset is in scalar elements.
    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::lined(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            // Shadow with the absolute scalar position of component `k`.
            let offset_input = offset_input + offset + k;
            // Component 0 of each input line: this kernel is only dispatched
            // when the common line size degenerated to 1 (see `copy_gpu_ref`).
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    // Remap the output line position through the layout.
    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}
156
/// Builds one packed storage word for logical word position `pos` from a
/// tensor whose `packed_dim` stores `packing` sub-elements per integer.
///
/// For each of the `packing` output slots: the logical element index is
/// decomposed into tensor coordinates (innermost-first, via `FastDivmod`);
/// the coordinate along `packed_dim` is split into a word index and an
/// intra-word slot; that slot's bits are extracted from the source word and
/// re-inserted at slot `n` of the result.
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: usize,
    in_shape: &Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
) -> N {
    // Bit width of one sub-element and a mask selecting exactly those bits.
    let type_size_bits = N::type_size_bits().comptime();
    let bits_per_elem = type_size_bits / packing;
    let mask = (1u32 << bits_per_elem) - 1;
    let mask = N::cast_from(mask);

    // First logical (unpacked) element covered by this output word.
    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        #[unroll]
        for i in 0..rank {
            // Digit extraction, innermost -> outermost.
            let dim = rank - i - 1;
            let (rem, mut local_pos) = in_shape[dim].div_mod(remainder);
            remainder = rem;
            if dim == packed_dim {
                // Split the packed coordinate into word index + slot.
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        // Shift the source slot down to bit 0, mask it, then shift it up to
        // destination slot `n`.
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}
201
/// Repacking copy: gathers packed sub-elements from a strided `input` (via
/// [`index_packed`]) and writes them contiguously into `output` through
/// `out_layout`.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel_packed<N: Int>(
    input: &Tensor<N>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size().comptime();
    let lines_per_thread = elems_per_thread / line_size;

    // Output offset is in lines; input offset is in packed words.
    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    // Trailing threads past the end of the output have nothing to do.
    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Line<N>>::lined(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            // Absolute packed-word position of component `k`.
            let offset_input = offset_input + offset + k;

            // Gather and repack one word per output component.
            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    // Remap the output line position through the layout.
    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}
246
247pub fn into_contiguous_packed<R: Runtime>(
256 client: &ComputeClient<R>,
257 input: &TensorHandleRef<'_, R>,
258 packed_dim: usize,
259 shape: &[usize],
260 packing: usize,
261 dtype: StorageType,
262) -> Result<TensorHandle<R>, LaunchError> {
263 let rank = shape.len();
264 if rank <= 1 {
265 return into_contiguous_ref(client, input, dtype);
266 }
267
268 let mut out_shape = shape.to_vec();
269 out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing);
270 let output = TensorHandle::empty(client, out_shape, dtype);
271
272 into_contiguous_packed_ref(
275 client,
276 input,
277 &output.as_ref(),
278 packed_dim,
279 shape,
280 packing,
281 dtype,
282 )?;
283
284 Ok(output)
285}
286
/// Copies `input` into `output` on-device, supporting arbitrary strided
/// output layouts via [`LinearLayout`].
///
/// Picks a line size both tensors support, sizes the per-thread workload
/// from an occupancy estimate, then dispatches either `copy_kernel` (shared
/// line size) or `copy_kernel_pack` (scalar reads packed into wider output
/// lines).
pub fn copy_gpu_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    let in_rank = input.strides.len();
    let out_rank = output.strides.len();
    // Largest vectorization width each side supports along its innermost dim.
    let line_size_in = tensor_line_size_parallel(
        client.io_optimized_line_sizes(dtype.size()),
        input.shape,
        input.strides,
        in_rank - 1,
    );
    let line_size_out = tensor_line_size_parallel(
        client.io_optimized_line_sizes(dtype.size()),
        output.shape,
        output.strides,
        out_rank - 1,
    );
    // Use the width both sides can honour.
    let line_size = line_size_in.min(line_size_out);

    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let cube_dim = CubeDim::new(client, num_vecs);
    // Rough count of lines in flight across all SMs; give each unit more
    // work only when the problem is large enough to keep the device busy.
    let simul_vecs = num_sm * cube_dim.num_elems();
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as usize * elems_per_unit;

    let last_dim = output.shape[out_rank - 1];

    // Shrink per-unit work until it divides the innermost output dimension,
    // so no unit straddles a row boundary.
    // NOTE(review): assumes `line_size` divides `last_dim` (otherwise this
    // loop would not terminate) — presumably guaranteed by
    // `tensor_line_size_parallel`; confirm.
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    // If the common line size degenerated to 1, still try to vectorize the
    // stores with the widest output line size dividing the per-unit count.
    let out_vec = if line_size > 1 {
        line_size
    } else {
        client
            .io_optimized_line_sizes(dtype.size())
            .filter(|it| num_elems_per_unit.is_multiple_of(*it))
            .max()
            .unwrap_or(1)
    };

    // Widest addressing mode either tensor requires.
    let address_type = input
        .required_address_type()
        .max(output.required_address_type());
    let input = linear_view(client, input, line_size);
    let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    // Mismatched line sizes mean each output line must be assembled from
    // several scalar input reads.
    let launch = if line_size != out_vec && out_vec > 1 {
        copy_kernel_pack::launch
    } else {
        copy_kernel::launch
    };

    launch(
        client,
        cube_count,
        cube_dim,
        address_type,
        input,
        output.as_tensor_arg(out_vec),
        out_layout,
        elems_per_unit,
        dtype,
    )
}
379
/// Launches the packed-copy kernel writing into a caller-provided `output`.
///
/// `shape` is the logical (unpacked) input shape and `packing` the number of
/// sub-elements per stored word. `packed_dim` appears to be counted from the
/// innermost axis — it is flipped into an absolute dimension index below
/// (TODO confirm against callers).
pub fn into_contiguous_packed_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    packed_dim: usize,
    shape: &[usize],
    packing: usize,
    dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    let in_rank = input.strides.len();
    let out_rank = output.strides.len();
    // Flip `packed_dim` into an absolute dimension index.
    let in_packed_dim = in_rank - packed_dim - 1;
    // Vectorization is chosen from the output side only; the input is read
    // scalar (see `as_tensor_arg(1)` below).
    let line_size = tensor_line_size_parallel(
        client.io_optimized_line_sizes(dtype.size()),
        output.shape,
        output.strides,
        out_rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);

    let cube_dim = CubeDim::new(client, num_vecs);
    // Scale per-unit work with how oversubscribed the device is.
    let simul_vecs = num_sm * cube_dim.num_elems();
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as usize * elems_per_unit;

    let last_dim = output.shape[out_rank - 1];

    // Shrink per-unit work until it divides the innermost output dimension.
    // NOTE(review): assumes `line_size` divides `last_dim` (otherwise this
    // loop would not terminate) — confirm against `tensor_line_size_parallel`.
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = LinearLayoutArgs::from_handle(client, output, line_size);

    // Widest addressing mode either tensor requires.
    let address_type = input
        .required_address_type()
        .max(output.required_address_type());
    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    // Precompute fast div/mod tables for the logical input shape.
    let in_shape = shape
        .iter()
        .map(|s| FastDivmodArgs::<usize>::new(client, *s))
        .collect();

    copy_kernel_packed::launch(
        client,
        cube_count,
        cube_dim,
        address_type,
        input.as_tensor_arg(1),
        output.as_tensor_arg(line_size),
        out_layout,
        in_shape,
        in_packed_dim,
        packing,
        in_rank,
        elems_per_unit,
        dtype,
    )
}
460
461pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
463 if shape.is_empty() {
464 return true;
465 }
466
467 for (&expected, &stride) in compact_strides(shape).iter().zip(strides) {
468 if expected != stride {
469 return false;
470 }
471 }
472
473 true
474}
475
/// Whether `strides` describe a "pitched" row-major layout for `shape`: the
/// innermost stride is 1, strides are non-increasing, and every dimension
/// except the pair of innermost ones is exactly compact — so padding (a
/// pitch) is only permitted between the last two dimensions.
///
/// An empty shape is trivially contiguous (previously this indexed
/// `strides[rank - 1]` before any rank check, panicking on rank 0).
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    // Guard before indexing: rank 0 would underflow `rank - 1`.
    if rank == 0 {
        return true;
    }
    // The innermost dimension must be unit-stride.
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank == 1 {
        return true;
    }

    // Strides must already be in non-increasing (row-major) order.
    let mut sorted = strides.to_vec();
    sorted.sort_unstable_by(|a, b| b.cmp(a));
    if sorted != strides {
        return false;
    }

    // All dims except the innermost pair must be exactly compact; only the
    // second-to-last stride may carry padding relative to the last extent.
    for i in 0..rank - 2 {
        if strides[i] != shape[i + 1] * strides[i + 1] {
            return false;
        }
    }
    true
}
503
504pub fn compact_strides(shape: &[usize]) -> Strides {
505 let rank = shape.len();
506 let mut strides = strides![1; rank];
507 for i in (0..rank - 1).rev() {
508 strides[i] = strides[i + 1] * shape[i + 1];
509 }
510 strides
511}