cubecl_std/tensor/contiguous/base.rs

use crate::{
    FastDivmod, FastDivmodArgs,
    tensor::{
        TensorHandle, into_contiguous_ref,
        layout::{
            Layout, LayoutExpand,
            linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view},
        },
    },
};
use cubecl::prelude::*;
use cubecl_core::{
    self as cubecl, calculate_cube_count_elemwise,
    ir::{LineSize, StorageType},
    tensor_line_size_parallel,
};

pub const NUM_SM_APPROX: u32 = 50;

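/// Computes the line index in `tensor` that corresponds to line `offset_layout` of
/// `layout`, considering only dimensions `dim_start..dim_end`. Each coordinate is wrapped
/// by `tensor`'s shape, so broadcast dimensions are handled.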
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: usize,
    dim_start: usize,
    dim_end: usize,
    #[comptime] unroll: bool,
) -> usize {
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.line_size()
}

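/// Maps the `offset_layout`-th line of a contiguous layout with the same shape to the
/// corresponding line index in the (possibly strided) `tensor`. When `rank` is known at
/// comptime, the loop over dimensions is unrolled.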
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: usize,
    #[comptime] rank: Option<usize>,
) -> usize {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.line_size()
}

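/// Same mapping as [`index_offset_contiguous`], but with the shape passed as precomputed
/// [`FastDivmod`] values and the strides as a separate sequence, so the per-dimension
/// division and modulo take the cheaper fast-divmod path.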
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: usize,
    shape: &Sequence<FastDivmod<usize>>,
    stride: &Sequence<usize>,
    #[comptime] line_size: LineSize,
) -> usize {
    let rank = shape.len().comptime();

    let offset_ref = offset * line_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;

        let (rem, ogwl) = shape[dim].div_mod(remainder);
        offset += ogwl * stride[dim];
        remainder = rem;
    }

    offset / line_size
}

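/// Each unit copies `elems_per_thread` lines from the (possibly strided) input view into
/// the contiguous output, staging the values in registers between the read and the write.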
#[cube(launch)]
fn into_contiguous_kernel<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let offset_linear = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    let mut registers = Array::<Line<N>>::lined(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_linear + i];
    }

    let offset_output = out_layout.to_source_pos(offset_linear);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}

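/// Variant of [`into_contiguous_kernel`] used when the input can only be read with line
/// size 1 but the output supports a wider line size: scalar reads are packed into output
/// lines in registers before being written back.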
#[cube(launch)]
fn into_contiguous_kernel_pack<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size().comptime();
    let lines_per_thread = elems_per_thread / line_size;

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::lined(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

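/// Gathers `packing` sub-elements for logical position `pos` from a tensor whose
/// `packed_dim` stores `packing` values per storage element, and repacks them into a
/// single storage element of type `N`.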
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: usize,
    in_shape: &Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
) -> N {
    let type_size_bits = N::type_size_bits().comptime();
    let bits_per_elem = type_size_bits / packing;
    let mask = (1u32 << bits_per_elem) - 1;
    let mask = N::cast_from(mask);

    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        #[unroll]
        for i in 0..rank {
            let dim = rank - i - 1;
            let (rem, mut local_pos) = in_shape[dim].div_mod(remainder);
            remainder = rem;
            if dim == packed_dim {
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}

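/// Kernel that makes a packed tensor contiguous: each output line is assembled by
/// gathering and repacking sub-elements from the strided input with [`index_packed`].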
#[cube(launch)]
fn into_contiguous_kernel_packed<N: Int>(
    input: &Tensor<N>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size().comptime();
    let lines_per_thread = elems_per_thread / line_size;

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Line<N>>::lined(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;

            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

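/// Makes a packed tensor contiguous, allocating a new output handle.
///
/// `shape` is the unpacked (logical) shape and `packing` is the number of logical values
/// stored per element of `dtype`; the last dimension of the output is divided by
/// `packing` (rounded up) to account for the packing. Falls back to
/// [`into_contiguous_ref`] for tensors of rank <= 1.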
pub fn into_contiguous_packed<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    packed_dim: usize,
    shape: &[usize],
    packing: usize,
    dtype: StorageType,
) -> Result<TensorHandle<R>, LaunchError> {
    let rank = shape.len();
    if rank <= 1 {
        return into_contiguous_ref(client, input, dtype);
    }

    let mut out_shape = shape.to_vec();
    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing);
    let output = TensorHandle::empty(client, out_shape, dtype);

    into_contiguous_packed_ref(
        client,
        input,
        &output.as_ref(),
        packed_dim,
        shape,
        packing,
        dtype,
    )?;

    Ok(output)
}

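/// Launches an `into_contiguous` kernel that copies `input` into the preallocated
/// `output`, choosing the line size, number of elements per unit, and kernel variant
/// (plain copy or re-packing copy) from the tensor shapes and device properties.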
pub fn into_contiguous_gpu_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    let rank = input.strides.len();
    let line_size = tensor_line_size_parallel(
        client.io_optimized_line_sizes(&dtype),
        input.shape,
        input.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let cube_dim = CubeDim::new(client, num_vecs);
    let simul_vecs = num_sm * cube_dim.num_elems();
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as usize * elems_per_unit;

    let last_dim = output.shape[rank - 1];

    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_vec = if line_size > 1 {
        line_size
    } else {
        client
            .io_optimized_line_sizes(&dtype)
            .filter(|it| num_elems_per_unit.is_multiple_of(*it))
            .max()
            .unwrap_or(1)
    };

    let input = linear_view(client, input, line_size);
    let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    let launch = if line_size != out_vec && out_vec > 1 {
        into_contiguous_kernel_pack::launch
    } else {
        into_contiguous_kernel::launch
    };

    launch(
        client,
        cube_count,
        cube_dim,
        input,
        output.as_tensor_arg(out_vec),
        out_layout,
        elems_per_unit,
        dtype,
    )
}

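/// Launches the packed `into_contiguous` kernel, writing into the preallocated `output`.
/// `shape` is the unpacked (logical) input shape; `packed_dim` and `packing` describe how
/// logical values are packed into storage elements of `dtype`.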
pub fn into_contiguous_packed_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    packed_dim: usize,
    shape: &[usize],
    packing: usize,
    dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    let rank = input.strides.len();
    let packed_dim = rank - packed_dim - 1;
    let line_size = tensor_line_size_parallel(
        client.io_optimized_line_sizes(&dtype),
        output.shape,
        output.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);

    let cube_dim = CubeDim::new(client, num_vecs);
    let simul_vecs = num_sm * cube_dim.num_elems();
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as usize * elems_per_unit;

    let last_dim = output.shape[rank - 1];

    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = LinearLayoutArgs::from_handle(client, output, line_size);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    let in_shape = shape
        .iter()
        .map(|s| FastDivmodArgs::<usize>::new(client, *s))
        .collect();

    into_contiguous_kernel_packed::launch(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(1),
        output.as_tensor_arg(line_size),
        out_layout,
        in_shape,
        packed_dim,
        packing,
        rank,
        elems_per_unit,
        dtype,
    )
}

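/// Returns `true` if the given strides describe a fully contiguous, row-major layout for
/// `shape` (i.e. they match [`compact_strides`]).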
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

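/// Returns `true` if the layout is contiguous except for possible row pitching: the last
/// dimension must have stride 1, strides must be in decreasing order, and every stride
/// except the second-to-last must tightly pack its inner dimensions. The second-to-last
/// stride may exceed the length of the last dimension (padded rows).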
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank <= 1 {
        return true;
    }

    let mut sorted = strides.to_vec();
    sorted.sort();
    sorted.reverse();

    if sorted != strides {
        return false;
    }

    for i in 0..rank - 2 {
        if strides[i] != shape[i + 1] * strides[i + 1] {
            return false;
        }
    }
    true
}

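/// Computes row-major (C-contiguous) strides for `shape`, with the last dimension having
/// stride 1.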
pub fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    for i in (0..rank - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
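
// Minimal illustrative checks for the host-side stride helpers above; the shapes and
// strides are arbitrary examples of the row-major and pitched conventions they assume.
#[cfg(test)]
mod stride_helper_tests {
    use super::*;

    #[test]
    fn compact_strides_are_row_major() {
        assert_eq!(compact_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    #[test]
    fn pitched_rows_count_as_pitched_but_not_contiguous() {
        // Rows of length 4 padded to a pitch of 8 elements.
        let shape = [2, 3, 4];
        let pitched = [24, 8, 1];
        assert!(!is_contiguous(&shape, &pitched));
        assert!(is_contiguous_pitched(&shape, &pitched));

        // Compact strides satisfy both predicates.
        let compact = compact_strides(&shape);
        assert!(is_contiguous(&shape, &compact));
        assert!(is_contiguous_pitched(&shape, &compact));
    }
}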