Skip to main content

cubecl_std/tensor/contiguous/
base.rs

1use crate::{
2    FastDivmod, FastDivmodArgs,
3    tensor::{
4        TensorHandle, into_contiguous_ref,
5        layout::{
6            Layout, LayoutExpand,
7            linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view},
8        },
9    },
10};
11use cubecl::prelude::*;
12use cubecl_core::{
13    self as cubecl, calculate_cube_count_elemwise,
14    ir::{LineSize, StorageType},
15    tensor_line_size_parallel,
16    zspace::{Strides, strides},
17};
18
/// Fallback streaming-multiprocessor count used for launch sizing when the runtime
/// does not report `num_streaming_multiprocessors`.
pub const NUM_SM_APPROX: u32 = 50;
20
/// Returns the offset of the tensor corresponding to the layout tensor.
///
/// Walks dimensions `dim_start..dim_end`, recovering each dimension's index from the
/// line-scaled linear offset using `layout`'s strides, then re-projects that index
/// (wrapped to `tensor`'s extent) onto `tensor`'s strides. The result is converted
/// back to line units.
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: usize,
    dim_start: usize,
    dim_end: usize,
    #[comptime] unroll: bool,
) -> usize {
    // Work in element units rather than line units.
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        // Index along dimension `i` in the layout's coordinate space...
        let ogwl = offset_ref / layout.stride(i);
        // ...wrapped to `tensor`'s extent and projected onto its stride.
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    // Convert the element offset back to line units.
    offset / tensor.line_size()
}
42
/// Returns the offset of the tensor corresponding to a contiguous layout.
///
/// Decomposes the line-scaled contiguous offset into per-dimension coordinates,
/// innermost dimension first, and projects them onto `tensor`'s strides.
/// Passing `rank` at comptime unrolls the loop.
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: usize,
    #[comptime] rank: Option<usize>,
) -> usize {
    // Only unroll when the rank is known at compile time.
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    // Work in element units rather than line units.
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        // Iterate dimensions from innermost to outermost.
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    // Convert the element offset back to line units.
    offset / tensor.line_size()
}
68
/// Returns the offset of the tensor corresponding to a contiguous layout.
///
/// Same decomposition as [`index_offset_contiguous`], but uses precomputed
/// [`FastDivmod`] constants for the per-dimension division/modulo and takes
/// explicit shape/stride sequences instead of a tensor.
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: usize,
    shape: &Sequence<FastDivmod<usize>>,
    stride: &Sequence<usize>,
    #[comptime] line_size: LineSize,
) -> usize {
    // Rank is known at comptime, so the loop below fully unrolls.
    let rank = shape.len().comptime();

    // Work in element units rather than line units.
    let offset_ref = offset * line_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll]
    for i in 0..rank {
        // Iterate dimensions from innermost to outermost.
        let dim = rank - i - 1;

        let (rem, ogwl) = shape[dim].div_mod(remainder);
        offset += ogwl * stride[dim];
        remainder = rem;
    }

    // Convert the element offset back to line units.
    offset / line_size
}
94
/// Plain copy kernel: each unit stages `elems_per_thread` lines from the linear
/// `input` view into a local array, then writes them to `output` at the position
/// given by `out_layout`.
///
/// NOTE(review): there is no tail bounds check here (unlike `copy_kernel_packed`);
/// presumably the host sizes the launch / the view masks out-of-range accesses —
/// confirm for sizes that do not divide evenly.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    // First line handled by this unit.
    let offset_linear = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    // Stage all reads before any write, in two tight unrolled loops.
    let mut registers = Array::<Line<N>>::lined(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_linear + i];
    }

    // Translate the linear position into the output's layout.
    let offset_output = out_layout.to_source_pos(offset_linear);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}
120
/// Copy kernel used when the output line size is wider than the input's.
/// The caller (`copy_gpu_ref`) selects this kernel only when the common input line
/// size is 1, so each output line is assembled from `line_size` single-element
/// input lines before being stored.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel_pack<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size().comptime();
    // `elems_per_thread` counts scalar elements; convert to output lines.
    let lines_per_thread = elems_per_thread / line_size;

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    // Input positions are scalar, hence the scale by the output line size.
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::lined(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            // Input lines hold a single element here, so index 0 reads it.
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}
156
/// Fetch all values required contained in a given position, unpack them, then repack them to their
/// new position.
///
/// `pos` addresses one storage element of the *output* packing: the result holds the
/// `packing` logical values starting at `pos * packing`. For each logical value, its
/// input coordinates are recovered from the unpacked shape, the packed storage
/// element is loaded, and the relevant bit-field is extracted and shifted into place.
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: usize,
    in_shape: &Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
) -> N {
    // Bit width of one logical (unpacked) value.
    let type_size_bits = N::type_size_bits().comptime();
    let bits_per_elem = type_size_bits / packing;
    // Mask selecting a single logical value's bits.
    let mask = (1u32 << bits_per_elem) - 1;
    let mask = N::cast_from(mask);

    // First logical element covered by the output storage element at `pos`.
    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        // Decompose the logical index into per-dimension coordinates (innermost
        // first) and accumulate the input storage offset.
        #[unroll]
        for i in 0..rank {
            let dim = rank - i - 1;
            let (rem, mut local_pos) = in_shape[dim].div_mod(remainder);
            remainder = rem;
            if dim == packed_dim {
                // Along the packed dimension, `packing` logical values share one
                // storage element; split into element index and sub-element slot.
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        // Bit position of the value inside the loaded storage element.
        let shift_in = packing_offset * bits_per_elem;
        // Bit position it takes in the repacked result.
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}
201
/// Copy kernel for packed tensors (e.g. quantized values): each output storage
/// element is reassembled from the input via `index_packed` — re-packing values
/// that may come from different input storage elements — then stored as
/// vectorized lines at the position given by `out_layout`.
#[cube(launch, address_type = "dynamic")]
fn copy_kernel_packed<N: Int>(
    input: &Tensor<N>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size().comptime();
    // `elems_per_thread` counts storage elements; convert to output lines.
    let lines_per_thread = elems_per_thread / line_size;

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    // Tail guard: units past the end of the output have nothing to do.
    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Line<N>>::lined(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;

            // Gather and re-pack one storage element worth of values.
            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}
246
247/// Make a jit tensor contiguous, using the pitched allocator if available.
248/// See [`create_tensor`](cubecl_runtime::client::ComputeClient::create_tensor).
249/// Handles unpacking and repacking packed tensors (i.e. quantized values).
250/// `shape` refers to the actual (unpacked) shape of the tensor, while `packing` specifies the
251/// number of elements in each storage element.
252///
253/// # Warning
254/// This assumes `u32` or `u8` packing.
255pub fn into_contiguous_packed<R: Runtime>(
256    client: &ComputeClient<R>,
257    input: &TensorHandleRef<'_, R>,
258    packed_dim: usize,
259    shape: &[usize],
260    packing: usize,
261    dtype: StorageType,
262) -> Result<TensorHandle<R>, LaunchError> {
263    let rank = shape.len();
264    if rank <= 1 {
265        return into_contiguous_ref(client, input, dtype);
266    }
267
268    let mut out_shape = shape.to_vec();
269    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing);
270    let output = TensorHandle::empty(client, out_shape, dtype);
271
272    // Should reinterpret as u8 if possible at some point, but requires modifying shape/strides so
273    // keep it simple for now
274    into_contiguous_packed_ref(
275        client,
276        input,
277        &output.as_ref(),
278        packed_dim,
279        shape,
280        packing,
281        dtype,
282    )?;
283
284    Ok(output)
285}
286
287/// Make a jit tensor contiguous.
288pub fn copy_gpu_ref<R: Runtime>(
289    client: &ComputeClient<R>,
290    input: &TensorHandleRef<'_, R>,
291    output: &TensorHandleRef<'_, R>,
292    dtype: StorageType,
293) -> Result<(), LaunchError> {
294    let num_elems: usize = input.shape.iter().product();
295
296    // Vectorization is only enabled when the last dimension is contiguous.
297    let in_rank = input.strides.len();
298    let out_rank = output.strides.len();
299    let line_size_in = tensor_line_size_parallel(
300        client.io_optimized_line_sizes(dtype.size()),
301        input.shape,
302        input.strides,
303        in_rank - 1,
304    );
305    let line_size_out = tensor_line_size_parallel(
306        client.io_optimized_line_sizes(dtype.size()),
307        output.shape,
308        output.strides,
309        out_rank - 1,
310    );
311    let line_size = line_size_in.min(line_size_out);
312
313    let num_vecs = num_elems / line_size as usize;
314    let num_sm = client
315        .properties()
316        .hardware
317        .num_streaming_multiprocessors
318        .unwrap_or(NUM_SM_APPROX);
319    let cube_dim = CubeDim::new(client, num_vecs);
320    let simul_vecs = num_sm * cube_dim.num_elems();
321    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
322        0..2 => 1,
323        2..4 => 2,
324        4..8 => 4,
325        8.. => 8,
326    };
327
328    let mut num_elems_per_unit = line_size as usize * elems_per_unit;
329
330    let last_dim = output.shape[out_rank - 1];
331
332    // If tensor is strided, elems_per_unit must be compatible with last dim
333    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
334        elems_per_unit /= 2;
335        num_elems_per_unit /= 2;
336    }
337
338    let out_vec = if line_size > 1 {
339        line_size
340    } else {
341        // Recompute because it needs to account for `num_elems_per_unit`
342        client
343            .io_optimized_line_sizes(dtype.size())
344            .filter(|it| num_elems_per_unit.is_multiple_of(*it))
345            .max()
346            .unwrap_or(1)
347    };
348
349    let address_type = input
350        .required_address_type()
351        .max(output.required_address_type());
352    let input = linear_view(client, input, line_size);
353    let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec);
354
355    let cube_count = calculate_cube_count_elemwise(
356        client,
357        num_elems.div_ceil(num_elems_per_unit as usize),
358        cube_dim,
359    );
360
361    let launch = if line_size != out_vec && out_vec > 1 {
362        copy_kernel_pack::launch
363    } else {
364        copy_kernel::launch
365    };
366
367    launch(
368        client,
369        cube_count,
370        cube_dim,
371        address_type,
372        input,
373        output.as_tensor_arg(out_vec),
374        out_layout,
375        elems_per_unit,
376        dtype,
377    )
378}
379
380/// Make a jit tensor contiguous.
381pub fn into_contiguous_packed_ref<R: Runtime>(
382    client: &ComputeClient<R>,
383    input: &TensorHandleRef<'_, R>,
384    output: &TensorHandleRef<'_, R>,
385    packed_dim: usize,
386    shape: &[usize],
387    packing: usize,
388    dtype: StorageType,
389) -> Result<(), LaunchError> {
390    let num_elems: usize = input.shape.iter().product();
391
392    // Vectorization is only enabled when the last dimension is contiguous.
393    let in_rank = input.strides.len();
394    let out_rank = output.strides.len();
395    let in_packed_dim = in_rank - packed_dim - 1;
396    let line_size = tensor_line_size_parallel(
397        client.io_optimized_line_sizes(dtype.size()),
398        output.shape,
399        output.strides,
400        out_rank - 1,
401    );
402    let num_vecs = num_elems / line_size as usize;
403    let num_sm = client
404        .properties()
405        .hardware
406        .num_streaming_multiprocessors
407        .unwrap_or(NUM_SM_APPROX);
408
409    let cube_dim = CubeDim::new(client, num_vecs);
410    let simul_vecs = num_sm * cube_dim.num_elems();
411    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
412        0..2 => 1,
413        2..4 => 2,
414        4..8 => 4,
415        8.. => 8,
416    };
417
418    let mut num_elems_per_unit = line_size as usize * elems_per_unit;
419
420    let last_dim = output.shape[out_rank - 1];
421
422    // If tensor is strided, elems_per_unit must be compatible with last dim
423    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
424        elems_per_unit /= 2;
425        num_elems_per_unit /= 2;
426    }
427
428    let out_layout = LinearLayoutArgs::from_handle(client, output, line_size);
429
430    let address_type = input
431        .required_address_type()
432        .max(output.required_address_type());
433    let cube_count = calculate_cube_count_elemwise(
434        client,
435        num_elems.div_ceil(num_elems_per_unit as usize),
436        cube_dim,
437    );
438
439    let in_shape = shape
440        .iter()
441        .map(|s| FastDivmodArgs::<usize>::new(client, *s))
442        .collect();
443
444    copy_kernel_packed::launch(
445        client,
446        cube_count,
447        cube_dim,
448        address_type,
449        input.as_tensor_arg(1),
450        output.as_tensor_arg(line_size),
451        out_layout,
452        in_shape,
453        in_packed_dim,
454        packing,
455        in_rank,
456        elems_per_unit,
457        dtype,
458    )
459}
460
461/// Checks if the tensor associated with the given shape and strides is contiguous.
462pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
463    if shape.is_empty() {
464        return true;
465    }
466
467    for (&expected, &stride) in compact_strides(shape).iter().zip(strides) {
468        if expected != stride {
469            return false;
470        }
471    }
472
473    true
474}
475
/// Checks if a tensor is only strided on the last dimension, and could be safely reinterpreted as
/// a 2D tensor with unit stride on the last dimension. This will always hold for non-permuted
/// tensors allocated on a runtime.
///
/// Expects `strides.len() == shape.len()`; a rank-0 tensor is trivially contiguous.
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    // Guard rank 0 before indexing: the previous version read `strides[rank - 1]`
    // first and panicked on empty shapes.
    if rank == 0 {
        return true;
    }
    // The innermost dimension must be unit-strided.
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank == 1 {
        return true;
    }

    // Strides must be non-increasing (i.e. the tensor is not permuted). Equivalent
    // to sorting descending and comparing, without the allocation.
    if strides.windows(2).any(|w| w[0] < w[1]) {
        return false;
    }

    // All dims except the two innermost must be tightly packed; the second-to-last
    // stride may exceed the packed value (row pitch padding).
    (0..rank - 2).all(|i| strides[i] == shape[i + 1] * strides[i + 1])
}
503
504pub fn compact_strides(shape: &[usize]) -> Strides {
505    let rank = shape.len();
506    let mut strides = strides![1; rank];
507    for i in (0..rank - 1).rev() {
508        strides[i] = strides[i + 1] * shape[i + 1];
509    }
510    strides
511}