cubecl_std/tensor/contiguous/base.rs

use crate::{
    FastDivmod, FastDivmodArgs,
    tensor::{
        TensorHandle, into_contiguous_ref,
        layout::{
            Layout, LayoutExpand,
            linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view},
        },
    },
};
use cubecl::prelude::*;
use cubecl_core::{
    self as cubecl, calculate_cube_count_elemwise,
    ir::{LineSize, StorageType},
    tensor_line_size_parallel,
};

/// Fallback estimate for the number of streaming multiprocessors, used when the hardware
/// properties do not report it.
pub const NUM_SM_APPROX: u32 = 50;

/// Returns the offset into `tensor` that corresponds to the position `offset_layout` in the
/// layout tensor.
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: usize,
    dim_start: usize,
    dim_end: usize,
    #[comptime] unroll: bool,
) -> usize {
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.line_size()
}

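// Illustrative only: a host-side sketch of the arithmetic in `index_offset_with_layout`, written
// as plain Rust over slices and covering every dimension (the cube version restricts itself to
// `dim_start..dim_end` and works in lines). The helper name is hypothetical, not part of this
// module's API.
#[allow(dead_code)]
fn index_offset_with_layout_host(
    tensor_shape: &[usize],
    tensor_strides: &[usize],
    layout_strides: &[usize],
    offset_layout: usize,
) -> usize {
    let mut offset = 0;
    for i in 0..tensor_shape.len() {
        // Recover the coordinate along dimension `i` from the layout's linear offset, then
        // re-project it onto the (possibly differently strided) tensor.
        let coord = (offset_layout / layout_strides[i]) % tensor_shape[i];
        offset += coord * tensor_strides[i];
    }
    offset
}
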
/// Returns the offset into `tensor` that corresponds to the position `offset_layout` in a
/// contiguous layout with the same shape.
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: usize,
    #[comptime] rank: Option<usize>,
) -> usize {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.line_size()
}

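// Illustrative only: a host-side sketch of the decomposition used by `index_offset_contiguous`
// and `index_offset_contiguous_fastdivmod` (the latter just replaces the `/` and `%` pair with a
// precomputed `FastDivmod`). The helper name and the values in the example comment are
// hypothetical, not part of this module's API.
#[allow(dead_code)]
fn index_offset_contiguous_host(shape: &[usize], strides: &[usize], offset: usize) -> usize {
    let mut remainder = offset;
    let mut out = 0;
    // Peel coordinates off the contiguous linear index from the innermost dimension outwards,
    // then re-project each coordinate with the tensor's actual strides.
    for dim in (0..shape.len()).rev() {
        let coord = remainder % shape[dim];
        out += coord * strides[dim];
        remainder /= shape[dim];
    }
    // Example: the contiguous index 17 in shape [2, 3, 4] decomposes into coordinates (1, 1, 1);
    // with the permuted strides [1, 2, 6] that element lives at offset 1 + 2 + 6 = 9.
    out
}
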
/// Returns the offset corresponding to a contiguous layout, like [`index_offset_contiguous`],
/// but using precomputed [`FastDivmod`] shapes and explicit strides.
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: usize,
    shape: &Sequence<FastDivmod<usize>>,
    stride: &Sequence<usize>,
    #[comptime] line_size: LineSize,
) -> usize {
    let rank = shape.len().comptime();

    let offset_ref = offset * line_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll]
    for i in 0..rank {
        let dim = rank - i - 1;

        let (rem, ogwl) = shape[dim].div_mod(remainder);
        offset += ogwl * stride[dim];
        remainder = rem;
    }

    offset / line_size
}

/// Copies `input` into the contiguous `output`, with each unit handling `elems_per_thread`
/// lines: all lines are read into registers first, then written out in a second pass.
#[cube(launch)]
fn into_contiguous_kernel<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let offset_linear = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    let mut registers = Array::<Line<N>>::lined(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_linear + i];
    }

    let offset_output = out_layout.to_source_pos(offset_linear);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Variant of [`into_contiguous_kernel`] used when the output can be written with a wider line
/// size than the input can be read with: scalar input lines are gathered and repacked into full
/// output lines.
#[cube(launch)]
fn into_contiguous_kernel_pack<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size().comptime();
    let lines_per_thread = elems_per_thread / line_size;

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::lined(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Fetches the packed values required for a given position, unpacks them, then repacks them into
/// their new position.
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: usize,
    in_shape: &Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
) -> N {
    let type_size_bits = N::type_size_bits().comptime();
    let bits_per_elem = type_size_bits / packing;
    let mask = (1u32 << bits_per_elem) - 1;
    let mask = N::cast_from(mask);

    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        #[unroll]
        for i in 0..rank {
            let dim = rank - i - 1;
            let (rem, mut local_pos) = in_shape[dim].div_mod(remainder);
            remainder = rem;
            if dim == packed_dim {
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}

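// Illustrative only: the bit manipulation `index_packed` performs for one output storage element,
// shown host-side for `u32` storage with four 8-bit values per word (`packing = 4`). `picks`
// lists, for each output sub-element, which input word to read and which sub-element of that word
// to take. The helper name and the values in the example comment are hypothetical.
#[allow(dead_code)]
fn repack_u32_host(input: &[u32], picks: &[(usize, u32)]) -> u32 {
    let packing = 4u32;
    let bits_per_elem = 32 / packing; // 8 bits per packed value
    let mask = (1u32 << bits_per_elem) - 1; // 0xFF

    let mut out = 0u32;
    for (n, (word, sub)) in picks.iter().enumerate() {
        // Shift the wanted sub-element down to the low bits, mask it out, then place it into
        // slot `n` of the output word.
        let value = (input[*word] >> (*sub * bits_per_elem)) & mask;
        out |= value << (n as u32 * bits_per_elem);
    }
    out
}
// For example, repack_u32_host(&[0xDDCC_BBAA, 0x4433_2211], &[(0, 2), (1, 0), (0, 3), (1, 1)])
// gathers 0xCC, 0x11, 0xDD, 0x22 and returns 0x22DD_11CC.
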
/// Copies a packed (e.g. quantized) tensor into a contiguous packed output, unpacking and
/// repacking each value via [`index_packed`].
#[cube(launch)]
fn into_contiguous_kernel_packed<N: Int>(
    input: &Tensor<N>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size().comptime();
    let lines_per_thread = elems_per_thread / line_size;

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Line<N>>::lined(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;

            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Make a jit tensor contiguous, using the pitched allocator if available.
/// See [create_tensor](cubecl_runtime::client::ComputeClient::create_tensor).
/// Handles unpacking and repacking packed tensors (i.e. quantized values).
/// `shape` refers to the actual (unpacked) shape of the tensor, while `packing` specifies the
/// number of elements in each storage element.
///
/// # Warning
/// This assumes `u32` or `u8` packing.
pub fn into_contiguous_packed<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    packed_dim: usize,
    shape: &[usize],
    packing: usize,
    dtype: StorageType,
) -> Result<TensorHandle<R>, LaunchError> {
    let rank = shape.len();
    if rank <= 1 {
        return into_contiguous_ref(client, input, dtype);
    }

    let mut out_shape = shape.to_vec();
    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing);
    let output = TensorHandle::empty(client, out_shape, dtype);

    // Should reinterpret as u8 if possible at some point, but requires modifying shape/strides so
    // keep it simple for now
    into_contiguous_packed_ref(
        client,
        input,
        &output.as_ref(),
        packed_dim,
        shape,
        packing,
        dtype,
    )?;

    Ok(output)
}

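// A hedged usage sketch for `into_contiguous_packed` above (comments only; the variable names,
// shape, and packing are hypothetical). Given a `client`, a strided `input` handle whose `u32`
// storage packs four 8-bit values along the innermost dimension, and the *unpacked* shape
// `[batch, channels]`, something like the following produces a contiguous, repacked output whose
// last dimension is `channels.div_ceil(4)` storage elements:
//
//     let output = into_contiguous_packed(
//         &client,
//         &input.as_ref(),
//         packed_dim,           // index of the packed dimension
//         &[batch, channels],   // unpacked shape
//         4,                    // packing: values per storage element
//         dtype,
//     )?;
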
/// Make a jit tensor contiguous, writing the result into the provided `output` handle.
pub fn into_contiguous_gpu_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let rank = input.strides.len();
    let line_size = tensor_line_size_parallel(
        client.io_optimized_line_sizes(&dtype),
        input.shape,
        input.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let cube_dim = CubeDim::new(client, num_vecs);
    let simul_vecs = num_sm * cube_dim.num_elems();
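
    // Worked example (hypothetical numbers): with 1_048_576 elements at line_size 4 there are
    // 262_144 lines to copy; with 50 SMs and a 256-unit cube, simul_vecs = 12_800, so
    // num_vecs / simul_vecs = 20 and each unit gets 8 lines (the `8..` arm below).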
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as usize * elems_per_unit;

    let last_dim = output.shape[rank - 1];

    // If tensor is strided, elems_per_unit must be compatible with last dim
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    // When the input can't be vectorized (`line_size == 1`), still try to vectorize the output by
    // picking the widest line size that divides each unit's span; the pack kernel then gathers
    // the scalar reads into full output lines.
    let out_vec = if line_size > 1 {
        line_size
    } else {
        client
            .io_optimized_line_sizes(&dtype)
            .filter(|it| num_elems_per_unit.is_multiple_of(*it))
            .max()
            .unwrap_or(1)
    };

    let input = linear_view(client, input, line_size);
    let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    let launch = if line_size != out_vec && out_vec > 1 {
        into_contiguous_kernel_pack::launch
    } else {
        into_contiguous_kernel::launch
    };

    launch(
        client,
        cube_count,
        cube_dim,
        input,
        output.as_tensor_arg(out_vec),
        out_layout,
        elems_per_unit,
        dtype,
    )
}

/// Make a packed jit tensor contiguous, unpacking and repacking the values into the provided
/// `output` handle.
pub fn into_contiguous_packed_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    packed_dim: usize,
    shape: &[usize],
    packing: usize,
    dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let rank = input.strides.len();
    let packed_dim = rank - packed_dim - 1;
    let line_size = tensor_line_size_parallel(
        client.io_optimized_line_sizes(&dtype),
        output.shape,
        output.strides,
        rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);

    let cube_dim = CubeDim::new(client, num_vecs);
    let simul_vecs = num_sm * cube_dim.num_elems();
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as usize * elems_per_unit;

    let last_dim = output.shape[rank - 1];

    // If tensor is strided, elems_per_unit must be compatible with last dim
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = LinearLayoutArgs::from_handle(client, output, line_size);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    let in_shape = shape
        .iter()
        .map(|s| FastDivmodArgs::<usize>::new(client, *s))
        .collect();

    into_contiguous_kernel_packed::launch(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(1),
        output.as_tensor_arg(line_size),
        out_layout,
        in_shape,
        packed_dim,
        packing,
        rank,
        elems_per_unit,
        dtype,
    )
}

/// Checks if the tensor associated with the given shape and strides is contiguous.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

/// Checks if a tensor is only strided on the last dimension, and could be safely reinterpreted as
/// a 2D tensor with unit stride on the last dimension. This will always hold for non-permuted
/// tensors allocated on a runtime.
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();
    if rank == 0 {
        return true;
    }
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank <= 1 {
        return true;
    }

    let mut sorted = strides.to_vec();
    sorted.sort();
    sorted.reverse();

    if sorted != strides {
        return false;
    }

    for i in 0..rank - 2 {
        if strides[i] != shape[i + 1] * strides[i + 1] {
            return false;
        }
    }
    true
}

/// Computes the strides of a compact (row-major, contiguous) tensor with the given shape.
pub fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
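
// Illustrative checks for the host-side helpers above; the shapes and strides are hypothetical
// examples, not cases taken from the library's own test suite.
#[cfg(test)]
mod contiguous_helper_examples {
    use super::*;

    #[test]
    fn compact_strides_is_row_major() {
        // Row-major strides for shape [2, 3, 4]: the innermost dimension is unit-strided.
        assert_eq!(compact_strides(&[2, 3, 4]), vec![12, 4, 1]);
    }

    #[test]
    fn contiguous_versus_pitched() {
        // Exactly row-major: contiguous in both senses.
        assert!(is_contiguous(&[2, 3, 4], &[12, 4, 1]));
        assert!(is_contiguous_pitched(&[2, 3, 4], &[12, 4, 1]));

        // Rows padded to a pitch of 6 elements: no longer compact, but still "pitched".
        assert!(!is_contiguous(&[2, 3, 4], &[18, 6, 1]));
        assert!(is_contiguous_pitched(&[2, 3, 4], &[18, 6, 1]));

        // Permuted strides: neither.
        assert!(!is_contiguous(&[2, 3, 4], &[1, 2, 6]));
        assert!(!is_contiguous_pitched(&[2, 3, 4], &[1, 2, 6]));
    }
}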