cubecl_std/tensor/contiguous.rs

use crate::{
    FastDivmod, FastDivmodArgs,
    tensor::layout::{
        Layout, LayoutExpand,
        linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view},
    },
};

use super::TensorHandle;
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel};

pub const NUM_SM_APPROX: u32 = 50;

/// Returns the offset into `tensor` corresponding to `offset_layout`, a position in the `layout`
/// tensor, considering only dimensions `dim_start..dim_end`.
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: u32,
    dim_start: u32,
    dim_end: u32,
    #[comptime] unroll: bool,
) -> u32 {
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.line_size()
}

/// Returns the offset into `tensor` corresponding to `offset_layout`, a position in a contiguous
/// layout of the same shape.
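///
/// A small worked example (hypothetical shape and strides, with `line_size == 1`):
///
/// ```text
/// shape = [2, 3], tensor strides = [1, 2] (transposed), offset_layout = 4
/// dim 1: 4 % 3 = 1  ->  offset += 1 * 2 = 2,  remainder = 4 / 3 = 1
/// dim 0: 1 % 2 = 1  ->  offset += 1 * 1 = 1,  remainder = 0
/// result: 3, the position of element (1, 1) in the strided tensor
/// ```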
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    #[comptime] rank: Option<u32>,
) -> u32 {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.line_size()
}

/// Returns the offset into a tensor described by `shape` and `stride`, corresponding to `offset`,
/// a position in a contiguous layout of the same shape.
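///
/// Same per-dimension decomposition as [`index_offset_contiguous`], but the shape is provided as
/// precomputed [`FastDivmod`] values, whose `div_mod(x)` returns `(x / d, x % d)` (as used below),
/// so the divisors can be prepared once on the host.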
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: u32,
    shape: &Sequence<FastDivmod>,
    stride: &Sequence<u32>,
    #[comptime] line_size: u32,
) -> u32 {
    let rank = comptime![shape.len()];

    let offset_ref = offset * line_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    let mut dim = comptime![rank - 1];

    #[unroll]
    for _ in 0..rank {
        let shape = shape.index(dim);
        let (rem, ogwl) = shape.div_mod(remainder);
        offset += ogwl * stride.index(dim);
        remainder = rem;

        comptime![dim = dim.saturating_sub(1);]
    }

    offset / line_size
}

#[cube(launch)]
fn into_contiguous_kernel<N: CubePrimitive>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: u32,
) {
    let offset_output = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    let mut registers = Array::<Line<N>>::vectorized(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_output + i];
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}

#[cube(launch)]
fn into_contiguous_kernel_pack<N: CubePrimitive>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: u32,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Fetch all packed values needed for a given output position, unpack them from their source
/// storage words, then repack them into a single word at their new position.
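///
/// Illustration of the bit manipulation (assuming `N = u32` and `packing = 4`, so
/// `bits_per_elem = 8` and `mask = 0xFF`):
///
/// ```text
/// source element at packing_offset 2:  value = (packed_val >> 16) & 0xFF
/// written into the output word at n:   out  |= value << (n * 8)
/// ```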
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: u32,
    in_shape: &Sequence<FastDivmod>,
    #[comptime] packed_dim: u32,
    #[comptime] packing: u32,
    #[comptime] rank: u32,
) -> N {
    let bits_per_elem = comptime![N::elem_size_bits() / packing];
    let mask = comptime![(1u32 << bits_per_elem) - 1];
    let mask = N::cast_from(mask);

    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        #[unroll]
        for i in 0..rank {
            let dim = comptime![rank - i - 1];
            let (rem, mut local_pos) = in_shape.index(dim).div_mod(remainder);
            remainder = rem;
            if comptime![dim == packed_dim] {
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}

#[cube(launch)]
fn into_contiguous_kernel_packed<N: Int>(
    input: &Tensor<N>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod>,
    #[comptime] packed_dim: u32,
    #[comptime] packing: u32,
    #[comptime] rank: u32,
    #[comptime] elems_per_thread: u32,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;

            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Make a jit tensor contiguous.
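///
/// Illustrative usage only: `client` and `input` are assumed to be provided by the caller (how
/// they are obtained depends on the runtime), and `f32` stands in for any `CubePrimitive` element
/// type.
///
/// ```ignore
/// fn make_contiguous<R: Runtime>(
///     client: &ComputeClient<R::Server>,
///     input: &TensorHandleRef<'_, R>,
/// ) -> TensorHandle<R, f32> {
///     // Allocates a compact output of the same shape and launches the copy kernel.
///     into_contiguous::<R, f32>(client, input)
/// }
/// ```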
pub fn into_contiguous<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    let num_elems: usize = input.shape.iter().product();

    let handle = client.empty(num_elems * size_of::<E>());
    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle);

    into_contiguous_ref::<R, E>(client, input, &output.as_ref());

    output
}

/// Make a jit tensor contiguous, using the pitched allocator if available.
/// See [create_tensor](cubecl_runtime::client::ComputeClient::create_tensor).
pub fn into_contiguous_pitched<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    if input.shape.len() <= 1 {
        return into_contiguous(client, input);
    }

    let output = TensorHandle::empty(client, input.shape.to_vec());

    into_contiguous_ref::<R, E>(client, input, &output.as_ref());

    output
}

/// Make a jit tensor contiguous, using the pitched allocator if available.
/// See [create_tensor](cubecl_runtime::client::ComputeClient::create_tensor).
/// Handles unpacking and repacking packed tensors (i.e. quantized values).
/// `shape` refers to the actual (unpacked) shape of the tensor, while `packing` specifies the
/// number of elements in each storage element.
///
/// # Warning
/// This assumes `u32` or `u8` packing.
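///
/// A small worked example of the shape arithmetic (illustrative values):
///
/// ```text
/// shape = [2, 10], packing = 4
/// output storage shape = [2, 10.div_ceil(4)] = [2, 3]
/// ```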
pub fn into_contiguous_packed<R: Runtime, I: Int>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    shape: &[usize],
    packing: u32,
) -> TensorHandle<R, I> {
    let rank = shape.len();
    if rank <= 1 {
        return into_contiguous(client, input);
    }

    let mut out_shape = shape.to_vec();
    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing as usize);
    let output = TensorHandle::<R, I>::empty(client, out_shape);

    // Should reinterpret as u8 if possible at some point, but requires modifying shape/strides so
    // keep it simple for now
    into_contiguous_packed_ref::<R, I>(client, input, &output.as_ref(), shape, packing);

    output
}

/// Make a jit tensor contiguous, writing the result into `output`.
pub fn into_contiguous_ref<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
) {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let rank = input.strides.len();
    let line_size = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let simul_vecs = num_sm * CubeDim::default().num_elems();
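    // Heuristic: let each unit process more lines as the ratio of total lines to concurrently
    // resident units (num_sm * cube size) grows, capped at 8.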
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];

    // If tensor is strided, elems_per_unit must be compatible with last dim
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

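    // When the input cannot be vectorized (line_size == 1), still vectorize the output: pick the
    // largest supported line size that divides the per-unit element count; the pack kernel below
    // then assembles each output line from scalar reads.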
    let out_vec = if line_size > 1 {
        line_size
    } else {
        *R::supported_line_sizes()
            .iter()
            .filter(|it| num_elems_per_unit.is_multiple_of(**it as u32))
            .max()
            .unwrap_or(&1)
    };

    let input = linear_view(client, input, line_size);
    let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec);

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

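    // Use the packing kernel only when output lines must be assembled from narrower input reads.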
    let launch = if line_size != out_vec && out_vec > 1 {
        into_contiguous_kernel_pack::launch::<E, R>
    } else {
        into_contiguous_kernel::launch::<E, R>
    };

    launch(
        client,
        cube_count,
        cube_dim,
        input,
        output.as_tensor_arg(out_vec),
        out_layout,
        elems_per_unit,
    );
}

/// Make a packed jit tensor contiguous, writing the result into `output`.
/// See [`into_contiguous_packed`] for details on `shape` and `packing`.
pub fn into_contiguous_packed_ref<R: Runtime, E: Int>(
    client: &ComputeClient<R::Server>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    shape: &[usize],
    packing: u32,
) {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let rank = input.strides.len();
    let line_size = tensor_line_size_parallel(
        R::io_optimized_line_sizes(&E::as_type_native_unchecked()),
        output.shape,
        output.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let simul_vecs = num_sm * CubeDim::default().num_elems();
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];
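    // The packed dimension is the innermost input dimension with a stride of 1.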
    let packed_dim = input
        .strides
        .iter()
        .enumerate()
        .rev()
        .find(|(_, s)| **s == 1)
        .expect("At least one stride should be 1")
        .0;

    // If tensor is strided, elems_per_unit must be compatible with last dim
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = LinearLayoutArgs::from_handle(client, output, line_size);

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    let in_shape = shape
        .iter()
        .map(|s| FastDivmodArgs::new(client, *s as u32))
        .collect();

    into_contiguous_kernel_packed::launch::<E, R>(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(1),
        output.as_tensor_arg(line_size),
        out_layout,
        in_shape,
        packed_dim as u32,
        packing,
        rank as u32,
        elems_per_unit,
    );
}

/// Checks if the tensor associated with the given shape and strides is contiguous.
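///
/// For example (illustrative values):
///
/// ```text
/// is_contiguous(&[2, 3, 4], &[12, 4, 1]) == true   // compact row-major strides
/// is_contiguous(&[2, 3, 4], &[1, 2, 6])  == false  // column-major (permuted) strides
/// ```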
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

/// Checks if a tensor is contiguous except for a possible row pitch, i.e. only the stride of the
/// second-to-last dimension may exceed its compact value, so the tensor can be safely
/// reinterpreted as a 2D tensor with unit stride on the last dimension. This will always hold for
/// non-permuted tensors allocated on a runtime.
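///
/// For example (illustrative values):
///
/// ```text
/// is_contiguous_pitched(&[4, 10], &[16, 1]) == true   // rows padded to a pitch of 16
/// is_contiguous_pitched(&[4, 10], &[10, 1]) == true   // fully compact
/// is_contiguous_pitched(&[4, 10], &[1, 4])  == false  // permuted: last stride is not 1
/// ```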
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    // A rank-0 tensor is trivially contiguous.
    if rank == 0 {
        return true;
    }
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank <= 1 {
        return true;
    }

    let mut sorted = strides.to_vec();
    sorted.sort();
    sorted.reverse();

    if sorted != strides {
        return false;
    }

    for i in 0..rank - 2 {
        if strides[i] != shape[i + 1] * strides[i + 1] {
            return false;
        }
    }
    true
}

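/// Computes the compact (row-major) strides for `shape`: each stride is the product of the shapes
/// of all later dimensions.
///
/// ```text
/// compact_strides(&[2, 3, 4]) == [12, 4, 1]
/// ```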
pub fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps an empty shape from underflowing the range.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides