Skip to main content

cubecl_std/tensor/contiguous/
base.rs

1use crate::{
2    FastDivmod,
3    tensor::{
4        TensorHandle, into_contiguous,
5        layout::{
6            Layout, LayoutExpand,
7            linear::{LinearLayout, LinearView, linear_layout, linear_view},
8        },
9    },
10};
11use cubecl::prelude::*;
12use cubecl_core::{
13    self as cubecl, calculate_cube_count_elemwise,
14    ir::{StorageType, VectorSize},
15    tensor_vector_size_parallel,
16    zspace::{Strides, strides},
17};
18
/// Fallback estimate for the device's streaming-multiprocessor count, used
/// when `properties().hardware.num_streaming_multiprocessors` is unreported.
pub const NUM_SM_APPROX: u32 = 50;
20
/// Returns the offset of the tensor corresponding to the layout tensor.
///
/// `offset_layout` is a linear (vectorized) index into `layout`; for each
/// dimension in `dim_start..dim_end` the coordinate is recovered from the
/// layout tensor's stride and re-projected through `tensor`'s own stride.
/// The `% tensor.shape(i)` keeps the coordinate in range for `tensor`.
///
/// * `unroll` - comptime flag; set when the dimension range is known at
///   compile time so the loop can be unrolled.
///
/// The returned offset is in units of `tensor`'s vector width.
#[cube]
pub fn index_offset_with_layout<T: Scalar, N1: Size, L: Scalar, N2: Size>(
    tensor: &Tensor<Vector<T, N1>>,
    layout: &Tensor<Vector<L, N2>>,
    offset_layout: usize,
    dim_start: usize,
    dim_end: usize,
    #[comptime] unroll: bool,
) -> usize {
    // Work in scalar (unvectorized) units while combining strides.
    let offset_ref = offset_layout * tensor.vector_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        // Coordinate along dim `i` according to the layout tensor's stride.
        let ogwl = offset_ref / layout.stride(i);
        // Wrap into `tensor`'s extent, then project through its stride.
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    // Convert back to vectorized units.
    offset / tensor.vector_size()
}
42
/// Returns the offset of the tensor corresponding to a contiguous layout.
///
/// `offset_layout` is interpreted as a linear index over a contiguous version
/// of `tensor`; it is decomposed into per-dimension coordinates (innermost
/// dimension first) and re-projected through `tensor`'s actual strides.
///
/// * `rank` - when `Some`, the rank is comptime and the loop is unrolled;
///   when `None`, the runtime rank is queried from the tensor.
///
/// The returned offset is in units of `tensor`'s vector width.
#[cube]
pub fn index_offset_contiguous<T: Scalar, N: Size>(
    tensor: &Tensor<Vector<T, N>>,
    offset_layout: usize,
    #[comptime] rank: Option<usize>,
) -> usize {
    // Only unroll when the rank is known at compile time.
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    // Work in scalar units while decomposing the index.
    let offset_ref = offset_layout * tensor.vector_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        // Walk dimensions from innermost (fastest varying) to outermost.
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        // Coordinate along `dim`; the quotient carries to outer dimensions.
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    // Convert back to vectorized units.
    offset / tensor.vector_size()
}
68
/// Returns the offset of the tensor corresponding to a contiguous layout.
///
/// Same index decomposition as [`index_offset_contiguous`], but the shape is
/// supplied as [`FastDivmod`] entries so each per-dimension division/modulo
/// pair uses the precomputed fast path.
///
/// * `vector_size` - comptime vector width; `offset` is taken and returned in
///   units of this width.
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: usize,
    shape: &Sequence<FastDivmod<usize>>,
    stride: &Sequence<usize>,
    #[comptime] vector_size: VectorSize,
) -> usize {
    // Rank is comptime (sequence length), so the loop fully unrolls.
    let rank = shape.len().comptime();

    // Work in scalar units while decomposing the index.
    let offset_ref = offset * vector_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll]
    for i in 0..rank {
        // Innermost dimension first.
        let dim = rank - i - 1;

        // One fused op yields both the carry (quotient) and the coordinate.
        let (rem, ogwl) = shape[dim].div_mod(remainder);
        offset += ogwl * stride[dim];
        remainder = rem;
    }

    // Convert back to vectorized units.
    offset / vector_size
}
94
/// Copies `elems_per_thread` consecutive vectors per thread from `input`
/// (read through its linear view) into `output` at the position translated by
/// `out_layout`.
///
/// Values are staged in a register array so all loads are issued before any
/// store.
///
/// NOTE(review): there is no explicit bounds check here — presumably the
/// launch configuration and/or the `LinearView` keep accesses in range;
/// confirm against the callers (contrast with `copy_kernel_packed`).
#[cube(launch, address_type = "dynamic")]
fn copy_kernel<T: Numeric, N: Size>(
    input: &LinearView<Vector<T, N>>,
    output: &mut Tensor<Vector<T, N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(T)] _elem: StorageType,
) {
    // Each thread handles a contiguous run of `elems_per_thread` vectors.
    let offset_linear = ABSOLUTE_POS * elems_per_thread;

    let mut registers = Array::<Vector<T, N>>::new(elems_per_thread);

    // Stage all reads.
    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_linear + i];
    }

    // Translate the linear position into the output's storage position.
    let offset_output = out_layout.to_source_pos(offset_linear);

    // Write the staged values out.
    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}
119
/// Like [`copy_kernel`], but reads the input scalar-by-scalar and packs the
/// scalars into vectors of the output's width before storing.
///
/// `elems_per_thread` counts scalars; each thread assembles
/// `elems_per_thread / vector_size` output vectors.
///
/// NOTE(review): no explicit bounds check — presumably kept in range by the
/// launch configuration and/or the `LinearView`; confirm against the callers.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel_pack<T: Numeric, N: Size>(
    input: &LinearView<T>,
    output: &mut Tensor<Vector<T, N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(T)] _elem: StorageType,
) {
    let vector_size = output.vector_size().comptime();
    let vectors_per_thread = elems_per_thread / vector_size;

    // Output position is counted in vectors; input position in scalars.
    let offset_output = ABSOLUTE_POS * vectors_per_thread;
    let offset_input = offset_output * vector_size;

    let mut registers = Array::<Vector<T, N>>::new(vectors_per_thread);

    // Assemble each output vector from `vector_size` scalar loads.
    #[unroll]
    for i in 0..vectors_per_thread {
        let offset = i * vector_size;
        let mut reg = Vector::<T, N>::empty();
        #[unroll]
        for k in 0..vector_size {
            let offset_input = offset_input + offset + k;
            reg[k] = input[offset_input];
        }
        registers[i] = reg;
    }

    // Translate into the output's storage position, then store the vectors.
    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..vectors_per_thread {
        output[offset_output + i] = registers[i];
    }
}
155
/// Fetch all values required contained in a given position, unpack them, then repack them to their
/// new position.
///
/// `tensor` stores `packing` logical elements per storage element, packed
/// along `packed_dim` with `bits_per_elem = type_size_bits / packing` bits
/// each. `pos` is a position in storage elements of the contiguous output;
/// the `packing` logical elements it covers are gathered from their (possibly
/// strided) input locations and re-packed into one storage value.
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: usize,
    in_shape: &Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
) -> N {
    let type_size_bits = N::type_size_bits().comptime();
    let bits_per_elem = type_size_bits / packing;
    // Mask selecting the low `bits_per_elem` bits of a single packed element.
    let mask = (1u32 << bits_per_elem) - 1;
    let mask = N::cast_from(mask);

    // First logical (unpacked) element covered by this output position.
    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        // De-linearize the logical index (over the unpacked shape) into the
        // input's strided storage offset, innermost dimension first.
        #[unroll]
        for i in 0..rank {
            let dim = rank - i - 1;
            let (rem, mut local_pos) = in_shape[dim].div_mod(remainder);
            remainder = rem;
            if dim == packed_dim {
                // Along the packed dim, `packing` logical elements share one
                // storage element; split into storage index + bit slot.
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        // Extract the element from its input bit slot and reinsert it at
        // output slot `n`.
        let packed_val = tensor[offset];
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}
200
/// Copies a (possibly strided) packed input tensor into a contiguous packed
/// output, re-gathering and re-packing each storage element via
/// [`index_packed`].
///
/// `elems_per_thread` counts output storage elements; each thread produces
/// `elems_per_thread / vector_size` output vectors.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel_packed<T: Int, N: Size>(
    input: &Tensor<T>,
    output: &mut Tensor<Vector<T, N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
    #[comptime] elems_per_thread: usize,
    #[define(T)] _elem: StorageType,
) {
    let vector_size = output.vector_size().comptime();
    let vectors_per_thread = elems_per_thread / vector_size;

    // Output position is counted in vectors; input in storage elements.
    let offset_output = ABSOLUTE_POS * vectors_per_thread;
    let offset_input = offset_output * vector_size;

    // Out-of-range threads exit early; unlike `copy_kernel`, the input here
    // is indexed manually rather than through a masked linear view.
    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Vector<T, N>>::new(vectors_per_thread);

    // Assemble each output vector one re-packed storage element at a time.
    #[unroll]
    for i in 0..vectors_per_thread {
        let offset = i * vector_size;
        let mut reg = Vector::<T, N>::empty();
        #[unroll]
        for k in 0..vector_size {
            let offset_input = offset_input + offset + k;

            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    // Translate into the output's storage position, then store.
    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..vectors_per_thread {
        output[offset_output + i] = registers[i];
    }
}
245
246/// Make a jit tensor contiguous, using the pitched allocator if available.
247/// See [`create_tensor`](cubecl_runtime::client::ComputeClient::create_tensor).
248/// Handles unpacking and repacking packed tensors (i.e. quantized values).
249/// `shape` refers to the actual (unpacked) shape of the tensor, while `packing` specifies the
250/// number of elements in each storage element.
251///
252/// # Warning
253/// This assumes `u32` or `u8` packing.
254pub fn into_contiguous_packed<R: Runtime>(
255    client: &ComputeClient<R>,
256    input: TensorBinding<R>,
257    packed_dim: usize,
258    shape: &[usize],
259    packing: usize,
260    dtype: StorageType,
261) -> TensorHandle<R> {
262    let rank = shape.len();
263    if rank <= 1 {
264        return into_contiguous(client, input, dtype);
265    }
266
267    let mut out_shape = shape.to_vec();
268    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing);
269    let output = TensorHandle::empty(client, out_shape, dtype);
270
271    // Should reinterpret as u8 if possible at some point, but requires modifying shape/strides so
272    // keep it simple for now
273    into_contiguous_packed_ref(
274        client,
275        input,
276        output.clone().binding(),
277        packed_dim,
278        shape,
279        packing,
280        dtype,
281    );
282
283    output
284}
285
/// Make a jit tensor contiguous.
///
/// Copies `input` into `output` with a vectorized copy kernel, choosing the
/// vector width, elements-per-unit and cube count heuristically from the
/// shapes/strides and the device's streaming-multiprocessor count.
pub fn copy_gpu_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: TensorBinding<R>,
    output: TensorBinding<R>,
    dtype: StorageType,
) {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let in_rank = input.strides.len();
    let out_rank = output.strides.len();
    let vector_size_in = tensor_vector_size_parallel(
        client.io_optimized_vector_sizes(dtype.size()),
        &input.shape,
        &input.strides,
        in_rank - 1,
    );
    let vector_size_out = tensor_vector_size_parallel(
        client.io_optimized_vector_sizes(dtype.size()),
        &output.shape,
        &output.strides,
        out_rank - 1,
    );
    // Both sides must support the width, so take the minimum.
    let vector_size = vector_size_in.min(vector_size_out);

    let num_vecs = num_elems / vector_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let cube_dim = CubeDim::new(client, num_vecs);
    // `simul_vecs` approximates how many vectors run at once (one per unit);
    // give each unit more elements when the tensor is much larger than that.
    let simul_vecs = num_sm * cube_dim.num_elems();
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = vector_size as usize * elems_per_unit;

    let last_dim = output.shape[out_rank - 1];

    // If tensor is strided, elems_per_unit must be compatible with last dim
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_vec = if vector_size > 1 {
        vector_size
    } else {
        // Recompute because it needs to account for `num_elems_per_unit`
        client
            .io_optimized_vector_sizes(dtype.size())
            .filter(|it| num_elems_per_unit.is_multiple_of(*it))
            .max()
            .unwrap_or(1)
    };

    // Larger of the two tensors' required address types.
    let address_type = input
        .required_address_type(dtype.size())
        .max(output.required_address_type(dtype.size()));
    let input = linear_view(input);
    let out_layout = linear_layout(&output, out_vec);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    // When the output can use a wider vector than the common width, read
    // scalars and pack them; otherwise do a plain vectorized copy.
    let launch = if vector_size != out_vec && out_vec > 1 {
        copy_kernel_pack::launch
    } else {
        copy_kernel::launch
    };

    launch(
        client,
        cube_count,
        cube_dim,
        address_type,
        out_vec,
        input,
        output.into_tensor_arg(),
        out_layout,
        elems_per_unit,
        dtype,
    )
}
379
/// Make a jit tensor contiguous.
///
/// Packed variant of [`copy_gpu_ref`]: launches [`copy_kernel_packed`] to
/// gather, unpack and repack elements from a (possibly strided) packed
/// `input` into the contiguous packed `output`. `shape` is the unpacked
/// shape; `packing` is the number of logical elements per storage element.
pub fn into_contiguous_packed_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: TensorBinding<R>,
    output: TensorBinding<R>,
    packed_dim: usize,
    shape: &[usize],
    packing: usize,
    dtype: StorageType,
) {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let in_rank = input.strides.len();
    let out_rank = output.strides.len();
    // NOTE(review): `packed_dim` appears to be counted from the innermost
    // dimension and is flipped to an absolute axis index here — confirm
    // against the callers' convention.
    let in_packed_dim = in_rank - packed_dim - 1;
    // Only the output drives the vector width; the input is read through
    // `index_packed` element by element.
    let vector_size = tensor_vector_size_parallel(
        client.io_optimized_vector_sizes(dtype.size()),
        &output.shape,
        &output.strides,
        out_rank - 1,
    );
    let num_vecs = num_elems / vector_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);

    let cube_dim = CubeDim::new(client, num_vecs);
    // Same occupancy heuristic as `copy_gpu_ref`: more elements per unit
    // when the tensor far exceeds what runs simultaneously.
    let simul_vecs = num_sm * cube_dim.num_elems();
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = vector_size as usize * elems_per_unit;

    let last_dim = output.shape[out_rank - 1];

    // If tensor is strided, elems_per_unit must be compatible with last dim
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = linear_layout(&output, vector_size);

    // Larger of the two tensors' required address types.
    let address_type = input
        .required_address_type(dtype.size())
        .max(output.required_address_type(dtype.size()));
    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    // Unpacked shape, as FastDivmod entries for the kernel's de-linearization.
    let in_shape = shape.iter().copied().collect();

    copy_kernel_packed::launch(
        client,
        cube_count,
        cube_dim,
        address_type,
        vector_size,
        input.into_tensor_arg(),
        output.into_tensor_arg(),
        out_layout,
        in_shape,
        in_packed_dim,
        packing,
        in_rank,
        elems_per_unit,
        dtype,
    )
}
458
459/// Checks if the tensor associated with the given shape and strides is contiguous.
460pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
461    if shape.is_empty() {
462        return true;
463    }
464
465    for (&expected, &stride) in compact_strides(shape).iter().zip(strides) {
466        if expected != stride {
467            return false;
468        }
469    }
470
471    true
472}
473
/// Checks if a tensor is only strided on the last dimension, and could be safely reinterpreted as
/// a 2D tensor with unit stride on the last dimension. This will always hold for non-permuted
/// tensors allocated on a runtime.
///
/// Accepts rank-0 tensors (trivially contiguous), matching [`is_contiguous`];
/// previously an empty `shape` underflowed `rank - 1` and panicked.
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    // Guard before indexing `strides[rank - 1]`: rank 0 would underflow.
    if rank == 0 {
        return true;
    }
    // The innermost dimension must have unit stride.
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank == 1 {
        return true;
    }

    // Strides must be non-increasing, i.e. the dims are not permuted.
    // (Equivalent to the previous sort-descending-and-compare check.)
    if strides.windows(2).any(|w| w[0] < w[1]) {
        return false;
    }

    // All outer dimensions must be exactly packed; only the pitch of the
    // second-to-last dimension may exceed the last dimension's extent.
    (0..rank - 2).all(|i| strides[i] == shape[i + 1] * strides[i + 1])
}
501
502pub fn compact_strides(shape: &[usize]) -> Strides {
503    let rank = shape.len();
504    let mut strides = strides![1; rank];
505    for i in (0..rank - 1).rev() {
506        strides[i] = strides[i + 1] * shape[i + 1];
507    }
508    strides
509}