cubecl_std/tensor/contiguous.rs

use crate::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{
        Layout, LayoutExpand,
        linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view},
    },
};

use super::TensorHandle;
use cubecl::prelude::*;
use cubecl_core::{
    self as cubecl, calculate_cube_count_elemwise, ir::StorageType, tensor_line_size_parallel,
};

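/// Fallback estimate for the number of streaming multiprocessors, used when the
/// runtime does not report a hardware SM count.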
pub const NUM_SM_APPROX: u32 = 50;

/// Returns the offset into `tensor` that corresponds to the element at
/// `offset_layout` in the `layout` tensor.
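///
/// For example (hypothetical values, line size 1): with a contiguous `layout` of
/// shape `[2, 3]` and strides `[3, 1]`, and a permuted `tensor` with strides
/// `[1, 2]`, the index `4` maps to coordinates `(1, 1)` and therefore to the
/// offset `1 * 1 + 1 * 2 = 3`.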
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: u32,
    dim_start: u32,
    dim_end: u32,
    #[comptime] unroll: bool,
) -> u32 {
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.line_size()
}

/// Returns the offset into `tensor` that corresponds to the index `offset_layout`
/// of an equivalent contiguous layout.
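///
/// For example (hypothetical values, line size 1): with shape `[2, 3]` and strides
/// `[6, 1]` (a padded row stride), the contiguous index `4` decomposes into
/// `4 % 3 = 1` (last dim) and `4 / 3 = 1` (first dim), giving the strided offset
/// `1 * 6 + 1 * 1 = 7`.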
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    #[comptime] rank: Option<u32>,
) -> u32 {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.line_size()
}

/// Returns the offset corresponding to a contiguous layout, decomposing the index
/// with precomputed [`FastDivmod`] divisors for the shape.
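///
/// The decomposition mirrors [`index_offset_contiguous`], but the shape and strides
/// are passed as explicit [`Sequence`]s of comptime-known rank, so the loop is fully
/// unrolled and each division/modulo uses a precomputed divisor.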
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: u32,
    shape: &Sequence<FastDivmod>,
    stride: &Sequence<u32>,
    #[comptime] line_size: u32,
) -> u32 {
    let rank = comptime![shape.len()];

    let offset_ref = offset * line_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    let mut dim = comptime![rank - 1];

    #[unroll]
    for _ in 0..rank {
        let shape = shape.index(dim);
        let (rem, ogwl) = shape.div_mod(remainder);
        offset += ogwl * stride.index(dim);
        remainder = rem;

        comptime![dim = dim.saturating_sub(1);]
    }

    offset / line_size
}

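/// Copies `elems_per_thread` lines per unit from the (possibly strided) linear
/// `input` view into the contiguous `output`, staging the lines in registers
/// between the gather and the write.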
#[cube(launch)]
fn into_contiguous_kernel<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: u32,
    #[define(N)] _elem: StorageType,
) {
    let offset_output = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    let mut registers = Array::<Line<N>>::vectorized(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_output + i];
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}

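/// Variant of [`into_contiguous_kernel`] used when the input can only be read with
/// line size 1 but the output supports a wider line size: each output line is
/// assembled from `line_size` scalar reads before being written.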
#[cube(launch)]
fn into_contiguous_kernel_pack<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: u32,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Fetch all packed values that contribute to a given position, unpack them, then
/// repack them into their new position.
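///
/// For example (hypothetical values): with `u32` storage and `packing = 8`, each
/// storage word holds eight 4-bit values (`bits_per_elem = 4`, `mask = 0xF`); the
/// value for logical position `n` is extracted from its source word with
/// `(word >> (packing_offset * 4)) & 0xF` and re-inserted at bit `n * 4` of the
/// output word.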
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: u32,
    in_shape: &Sequence<FastDivmod>,
    #[comptime] packed_dim: u32,
    #[comptime] packing: u32,
    #[comptime] rank: u32,
) -> N {
    let elem_size_bits = N::elem_size_bits();
    let bits_per_elem = comptime![elem_size_bits / packing];
    let mask = comptime![(1u32 << bits_per_elem) - 1];
    let mask = N::cast_from(mask);

    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        #[unroll]
        for i in 0..rank {
            let dim = comptime![rank - i - 1];
            let (rem, mut local_pos) = in_shape.index(dim).div_mod(remainder);
            remainder = rem;
            if comptime![dim == packed_dim] {
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}

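/// Variant of [`into_contiguous_kernel`] for packed (e.g. quantized) tensors: each
/// output storage word is reassembled from the strided input via [`index_packed`],
/// so values keep their packing while moving to their contiguous positions.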
#[cube(launch)]
fn into_contiguous_kernel_packed<N: Int>(
    input: &Tensor<N>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod>,
    #[comptime] packed_dim: u32,
    #[comptime] packing: u32,
    #[comptime] rank: u32,
    #[comptime] elems_per_thread: u32,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;

            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Make a tensor contiguous by copying it into a newly allocated contiguous handle.
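///
/// A minimal usage sketch (assumes `client`, a possibly-permuted `input` handle,
/// and its `dtype` are provided by the caller; not compiled as a doc test):
/// ```ignore
/// let contiguous: TensorHandle<R> = into_contiguous::<R>(&client, &input, dtype);
/// ```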
pub fn into_contiguous<R: Runtime>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> TensorHandle<R> {
    let num_elems: usize = input.shape.iter().product();

    let handle = client.empty(num_elems * dtype.size());
    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle, dtype);

    into_contiguous_ref::<R>(client, input, &output.as_ref(), dtype);

    output
}

/// Make a tensor contiguous, using the pitched allocator if available.
/// See [create_tensor](cubecl_runtime::client::ComputeClient::create_tensor).
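///
/// Falls back to [`into_contiguous`] for tensors of rank <= 1, where pitching does
/// not apply.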
pub fn into_contiguous_pitched<R: Runtime>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> TensorHandle<R> {
    if input.shape.len() <= 1 {
        return into_contiguous(client, input, dtype);
    }

    let output = TensorHandle::empty(client, input.shape.to_vec(), dtype);

    into_contiguous_ref::<R>(client, input, &output.as_ref(), dtype);

    output
}

/// Make a tensor contiguous, using the pitched allocator if available.
/// See [create_tensor](cubecl_runtime::client::ComputeClient::create_tensor).
/// Handles unpacking and repacking packed tensors (e.g. quantized values).
/// `shape` refers to the logical (unpacked) shape of the tensor, while `packing`
/// specifies the number of logical elements packed into each storage element.
///
/// # Warning
/// This assumes `u32` or `u8` packing.
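///
/// For example (hypothetical values): with `packing = 4`, a logical `shape` of
/// `[2, 10]` yields an output storage shape of `[2, 3]`, since the last dimension
/// becomes `10.div_ceil(4) = 3` storage elements.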
pub fn into_contiguous_packed<R: Runtime>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    shape: &[usize],
    packing: u32,
    dtype: StorageType,
) -> TensorHandle<R> {
    let rank = shape.len();
    if rank <= 1 {
        return into_contiguous(client, input, dtype);
    }

    let mut out_shape = shape.to_vec();
    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing as usize);
    let output = TensorHandle::<R>::empty(client, out_shape, dtype);

    // Should reinterpret as u8 if possible at some point, but requires modifying shape/strides so
    // keep it simple for now
    into_contiguous_packed_ref::<R>(client, input, &output.as_ref(), shape, packing, dtype);

    output
}

/// Make a tensor contiguous, writing the result into the provided `output` handle.
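///
/// Chooses between [`into_contiguous_kernel`] and [`into_contiguous_kernel_pack`]:
/// when the input can only be read with line size 1 but the output supports a wider
/// line size, the pack variant gathers scalar reads into full output lines.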
pub fn into_contiguous_ref<R: Runtime>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let rank = input.strides.len();
    let line_size = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let simul_vecs = num_sm * CubeDim::default().num_elems();
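    // Heuristic: give each unit more lines to copy as the number of lines per wave
    // of simultaneously resident units grows, capping at 8.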
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];

    // If tensor is strided, elems_per_unit must be compatible with last dim
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_vec = if line_size > 1 {
        line_size
    } else {
        *R::supported_line_sizes()
            .iter()
            .filter(|it| num_elems_per_unit.is_multiple_of(**it as u32))
            .max()
            .unwrap_or(&1)
    };

    let input = linear_view(client, input, line_size);
    let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec);

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    let launch = if line_size != out_vec && out_vec > 1 {
        into_contiguous_kernel_pack::launch::<R>
    } else {
        into_contiguous_kernel::launch::<R>
    };

    launch(
        client,
        cube_count,
        cube_dim,
        input,
        output.as_tensor_arg(out_vec),
        out_layout,
        elems_per_unit,
        dtype,
    );
}

/// Make a packed tensor contiguous, writing the result into the provided `output` handle.
pub fn into_contiguous_packed_ref<R: Runtime>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    shape: &[usize],
    packing: u32,
    dtype: StorageType,
) {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let rank = input.strides.len();
    let line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes(&dtype),
        output.shape,
        output.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let simul_vecs = num_sm * CubeDim::default().num_elems();
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];
    let packed_dim = input
        .strides
        .iter()
        .enumerate()
        .rev()
        .find(|(_, s)| **s == 1)
        .expect("At least one stride should be 1")
        .0;

    // If tensor is strided, elems_per_unit must be compatible with last dim
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = LinearLayoutArgs::from_handle(client, output, line_size);

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    let in_shape = shape
        .iter()
        .map(|s| FastDivmodArgs::new(client, *s as u32))
        .collect();

    into_contiguous_kernel_packed::launch::<R>(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(1),
        output.as_tensor_arg(line_size),
        out_layout,
        in_shape,
        packed_dim as u32,
        packing,
        rank as u32,
        elems_per_unit,
        dtype,
    );
}

/// Checks if the tensor associated with the given shape and strides is contiguous.
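///
/// A quick sanity check (not compiled as a doc test, since the re-export path may
/// differ):
/// ```ignore
/// assert!(is_contiguous(&[2, 3], &[3, 1]));
/// assert!(!is_contiguous(&[2, 3], &[1, 2]));
/// ```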
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

/// Checks if a tensor is contiguous except possibly for padding in the second-to-last
/// (row pitch) stride, so it can be safely reinterpreted as a 2D tensor with unit
/// stride on the last dimension. This will always hold for non-permuted tensors
/// allocated on a runtime.
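///
/// For example (hypothetical values): shape `[2, 3]` with strides `[4, 1]` (a padded
/// row pitch) or `[3, 1]` (fully contiguous) qualifies, while a permuted view with
/// strides `[1, 2]` does not.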
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank <= 1 {
        return true;
    }

    let mut sorted = strides.to_vec();
    sorted.sort();
    sorted.reverse();

    if sorted != strides {
        return false;
    }

    for i in 0..rank - 2 {
        if strides[i] != shape[i + 1] * strides[i + 1] {
            return false;
        }
    }
    true
}

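/// Computes the dense row-major (C-contiguous) strides for `shape`.
///
/// For example, `compact_strides(&[2, 3, 4])` returns `[12, 4, 1]`.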
pub fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    for i in (0..rank - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}