//! Contiguous-copy kernels and layout helpers (`cubecl_std/tensor/contiguous/base.rs`).
1use crate::{
2    FastDivmod, FastDivmodArgs,
3    tensor::{
4        TensorHandle, into_contiguous_ref,
5        layout::{
6            Layout, LayoutExpand,
7            linear::{LinearLayout, LinearLayoutArgs, LinearView, linear_view},
8        },
9    },
10};
11use cubecl::prelude::*;
12use cubecl_core::{
13    self as cubecl, calculate_cube_count_elemwise,
14    ir::{LineSize, StorageType},
15    tensor_line_size_parallel,
16};
17
18pub const NUM_SM_APPROX: u32 = 50;
19
/// Returns the offset of the tensor corresponding to the layout tensor.
///
/// Walks dimensions `dim_start..dim_end`, recovering per-dimension coordinates
/// from `offset_layout` through the `layout` tensor's strides, then re-projects
/// them onto `tensor`'s shape/strides. Offsets are scaled by the line size so
/// the arithmetic happens on unlined element indices.
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: usize,
    dim_start: usize,
    dim_end: usize,
    #[comptime] unroll: bool,
) -> usize {
    // Convert the line-indexed offset into an absolute element offset.
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        // Coordinate along dimension `i` as seen through the layout's strides.
        let ogwl = offset_ref / layout.stride(i);
        // Wrap into `tensor`'s extent (presumably to support broadcast-style
        // indexing — confirm) and apply `tensor`'s stride.
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    // Back to a line-indexed offset.
    offset / tensor.line_size()
}
41
/// Returns the offset of the tensor corresponding to a contiguous layout.
///
/// Decomposes the contiguous (row-major) position `offset_layout` into
/// per-dimension coordinates — innermost dimension first — and accumulates
/// `tensor`'s strides. When `rank` is supplied at comptime, the loop is unrolled.
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: usize,
    #[comptime] rank: Option<usize>,
) -> usize {
    // Unrolling is only possible when the rank is known at compile time.
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    // Work on unlined element indices.
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        // Iterate dimensions from innermost (fastest varying) to outermost.
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        // Coordinate along `dim`; the quotient carries to the outer dimensions.
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    // Back to a line-indexed offset.
    offset / tensor.line_size()
}
67
/// Returns the offset of the tensor corresponding to a contiguous layout.
///
/// Same decomposition as [`index_offset_contiguous`], but the extents come as
/// [`FastDivmod`] values so the per-dimension division/modulo use precomputed
/// fast-division constants instead of hardware division.
#[cube]
pub fn index_offset_contiguous_fastdivmod(
    offset: usize,
    shape: &Sequence<FastDivmod<usize>>,
    stride: &Sequence<usize>,
    #[comptime] line_size: LineSize,
) -> usize {
    let rank = shape.len().comptime();

    // Absolute element offset (the incoming `offset` counts lines).
    let offset_ref = offset * line_size;
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll]
    for i in 0..rank {
        // Innermost dimension first.
        let dim = rank - i - 1;

        // `div_mod` yields (remainder / shape[dim], remainder % shape[dim]);
        // the quotient carries to the outer dimensions.
        let (rem, ogwl) = shape[dim].div_mod(remainder);
        offset += ogwl * stride[dim];
        remainder = rem;
    }

    // Back to a line-indexed offset.
    offset / line_size
}
93
/// Copies `elems_per_thread` consecutive lines per unit from `input` to
/// `output`, remapping the linear position through `out_layout` (which may
/// target a strided/pitched destination).
#[cube(launch)]
fn copy_kernel<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let offset_linear = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    // Stage all reads in registers before writing, so loads and stores can
    // each be issued back-to-back.
    // NOTE(review): no explicit bounds guard here (unlike `copy_kernel_packed`);
    // presumably the `LinearView`/layout machinery guarantees in-range access
    // for the extra units launched by the ceiling division — confirm.
    let mut registers = Array::<Line<N>>::lined(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        registers[i] = input[offset_linear + i];
    }

    // Translate the linear line position into the output's storage offset.
    let offset_output = out_layout.to_source_pos(offset_linear);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}
119
/// Like [`copy_kernel`], but packs narrow input lines into wider output lines:
/// each output line of `line_size` lanes is assembled from `line_size`
/// consecutive input lines. Only lane 0 of each input line is read, so the
/// input is expected to be scalar (line size 1) — consistent with
/// `copy_gpu_ref`'s kernel selection, but confirm.
#[cube(launch)]
fn copy_kernel_pack<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size().comptime();
    // Each unit still moves `elems_per_thread` scalars, grouped into lines.
    let lines_per_thread = elems_per_thread / line_size;

    // Output position counts lines; input position counts scalars.
    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::lined(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        // Gather `line_size` scalar inputs into one register line.
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    // Remap the line-indexed linear position into the output's storage offset.
    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}
155
/// Fetch all values contained in a given position, unpack them, then repack
/// them to their new position.
///
/// `tensor` stores `packing` logical values per storage word, each occupying
/// `type_size_bits / packing` bits along `packed_dim`. For the output word at
/// logical position `pos`, this gathers the `packing` logical elements
/// starting at `pos * packing`, locates each element's source word through the
/// (unpacked) `in_shape` coordinates, extracts its bit field, and ORs it into
/// the result at slot `n`.
#[cube]
fn index_packed<N: Int>(
    tensor: &Tensor<N>,
    pos: usize,
    in_shape: &Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
) -> N {
    let type_size_bits = N::type_size_bits().comptime();
    // Bit width of one logical element inside a storage word.
    let bits_per_elem = type_size_bits / packing;
    // Mask selecting a single element's bits.
    let mask = (1u32 << bits_per_elem) - 1;
    let mask = N::cast_from(mask);

    // First logical element covered by the output word at `pos`.
    let elem_pos = pos * packing;

    let mut out = N::new(0);
    for n in 0..packing {
        let mut remainder = elem_pos + n;
        let mut offset = 0;
        let mut packing_offset = 0;

        // Decompose the logical index into per-dimension coordinates
        // (innermost first) and accumulate the source storage offset.
        #[unroll]
        for i in 0..rank {
            let dim = rank - i - 1;
            let (rem, mut local_pos) = in_shape[dim].div_mod(remainder);
            remainder = rem;
            if dim == packed_dim {
                // On the packed dimension, `packing` logical positions share
                // one storage word; remember which slot this element occupies.
                packing_offset = local_pos % packing;
                local_pos /= packing;
            }
            offset += local_pos * tensor.stride(dim);
        }
        let packed_val = tensor[offset];
        // Shift the source bit field down, mask it, then place it at slot `n`.
        let shift_in = packing_offset * bits_per_elem;
        let shift_out = n * bits_per_elem;
        let value = (packed_val >> N::cast_from(shift_in)) & mask;

        out |= value << N::cast_from(shift_out);
    }
    out
}
200
/// Copies a (possibly strided) packed tensor into a contiguous packed output.
/// Each unit assembles `lines_per_thread` output lines, re-packing every
/// storage word element-by-element via [`index_packed`].
#[cube(launch)]
fn copy_kernel_packed<N: Int>(
    input: &Tensor<N>,
    output: &mut Tensor<Line<N>>,
    out_layout: LinearLayout,
    in_shape: Sequence<FastDivmod<usize>>,
    #[comptime] packed_dim: usize,
    #[comptime] packing: usize,
    #[comptime] rank: usize,
    #[comptime] elems_per_thread: usize,
    #[define(N)] _elem: StorageType,
) {
    let line_size = output.line_size().comptime();
    let lines_per_thread = elems_per_thread / line_size;

    // Output position counts lines; input position counts storage words.
    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    // Guard against the extra units launched by the ceiling division of the
    // cube count.
    if offset_output >= output.len() {
        terminate!()
    }

    let mut registers = Array::<Line<N>>::lined(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;

            // Re-pack one storage word from its (strided) source position.
            reg[k] = index_packed(input, offset_input, &in_shape, packed_dim, packing, rank);
        }
        registers[i] = reg;
    }

    // Remap the line-indexed linear position into the output's storage offset.
    let offset_output = out_layout.to_source_pos(offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}
245
246/// Make a jit tensor contiguous, using the pitched allocator if available.
247/// See [`create_tensor`](cubecl_runtime::client::ComputeClient::create_tensor).
248/// Handles unpacking and repacking packed tensors (i.e. quantized values).
249/// `shape` refers to the actual (unpacked) shape of the tensor, while `packing` specifies the
250/// number of elements in each storage element.
251///
252/// # Warning
253/// This assumes `u32` or `u8` packing.
254pub fn into_contiguous_packed<R: Runtime>(
255    client: &ComputeClient<R>,
256    input: &TensorHandleRef<'_, R>,
257    packed_dim: usize,
258    shape: &[usize],
259    packing: usize,
260    dtype: StorageType,
261) -> Result<TensorHandle<R>, LaunchError> {
262    let rank = shape.len();
263    if rank <= 1 {
264        return into_contiguous_ref(client, input, dtype);
265    }
266
267    let mut out_shape = shape.to_vec();
268    out_shape[rank - 1] = out_shape[rank - 1].div_ceil(packing);
269    let output = TensorHandle::empty(client, out_shape, dtype);
270
271    // Should reinterpret as u8 if possible at some point, but requires modifying shape/strides so
272    // keep it simple for now
273    into_contiguous_packed_ref(
274        client,
275        input,
276        &output.as_ref(),
277        packed_dim,
278        shape,
279        packing,
280        dtype,
281    )?;
282
283    Ok(output)
284}
285
/// Make a jit tensor contiguous.
///
/// Copies `input` into `output`, negotiating a line (vector) size compatible
/// with both tensors' innermost strides and sizing each unit's workload by how
/// oversubscribed the device would be.
pub fn copy_gpu_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let in_rank = input.strides.len();
    let out_rank = output.strides.len();
    let line_size_in = tensor_line_size_parallel(
        client.io_optimized_line_sizes(&dtype),
        input.shape,
        input.strides,
        in_rank - 1,
    );
    let line_size_out = tensor_line_size_parallel(
        client.io_optimized_line_sizes(&dtype),
        output.shape,
        output.strides,
        out_rank - 1,
    );
    // A shared line size must be supported by both sides.
    let line_size = line_size_in.min(line_size_out);

    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let cube_dim = CubeDim::new(client, num_vecs);
    // Lines processed simultaneously if every SM runs one cube.
    let simul_vecs = num_sm * cube_dim.num_elems();
    // Give each unit more work the more the device is oversubscribed (cap 8).
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as usize * elems_per_unit;

    let last_dim = output.shape[out_rank - 1];

    // If tensor is strided, elems_per_unit must be compatible with last dim
    // (a unit's span must not straddle a row boundary of a pitched output).
    // NOTE(review): assumes `line_size` divides `last_dim` so this terminates —
    // confirm `tensor_line_size_parallel` guarantees that.
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_vec = if line_size > 1 {
        line_size
    } else {
        // Recompute because it needs to account for `num_elems_per_unit`:
        // even with a scalar input, the output may still be written in lines.
        client
            .io_optimized_line_sizes(&dtype)
            .filter(|it| num_elems_per_unit.is_multiple_of(*it))
            .max()
            .unwrap_or(1)
    };

    let input = linear_view(client, input, line_size);
    let out_layout = LinearLayoutArgs::from_handle(client, output, out_vec);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    // `copy_kernel_pack` re-packs scalar reads into wider output lines; only
    // needed when the two sides ended up with different line sizes.
    let launch = if line_size != out_vec && out_vec > 1 {
        copy_kernel_pack::launch
    } else {
        copy_kernel::launch
    };

    launch(
        client,
        cube_count,
        cube_dim,
        input,
        output.as_tensor_arg(out_vec),
        out_layout,
        elems_per_unit,
        dtype,
    )
}
374
/// Make a jit tensor contiguous.
///
/// Packed variant of [`copy_gpu_ref`]: launches [`copy_kernel_packed`], which
/// unpacks each logical element from the (strided) input and re-packs it into
/// the contiguous output. `shape` is the unpacked logical shape; `packing` is
/// the number of logical elements per storage word.
pub fn into_contiguous_packed_ref<R: Runtime>(
    client: &ComputeClient<R>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
    packed_dim: usize,
    shape: &[usize],
    packing: usize,
    dtype: StorageType,
) -> Result<(), LaunchError> {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let in_rank = input.strides.len();
    let out_rank = output.strides.len();
    // NOTE(review): `packed_dim` appears to be counted from the innermost
    // dimension and is converted to an absolute axis index here — confirm
    // against callers.
    let in_packed_dim = in_rank - packed_dim - 1;
    // Only the output's line size matters; the input is read word-by-word.
    let line_size = tensor_line_size_parallel(
        client.io_optimized_line_sizes(&dtype),
        output.shape,
        output.strides,
        out_rank - 1,
    );
    let num_vecs = num_elems / line_size as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);

    let cube_dim = CubeDim::new(client, num_vecs);
    // Lines processed simultaneously if every SM runs one cube.
    let simul_vecs = num_sm * cube_dim.num_elems();
    // Give each unit more work the more the device is oversubscribed (cap 8).
    let mut elems_per_unit = match num_vecs / simul_vecs as usize {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = line_size as usize * elems_per_unit;

    let last_dim = output.shape[out_rank - 1];

    // If tensor is strided, elems_per_unit must be compatible with last dim
    // (a unit's span must not straddle a row boundary of a pitched output).
    while !last_dim.is_multiple_of(num_elems_per_unit as usize) {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = LinearLayoutArgs::from_handle(client, output, line_size);

    let cube_count = calculate_cube_count_elemwise(
        client,
        num_elems.div_ceil(num_elems_per_unit as usize),
        cube_dim,
    );

    // Precompute fast div/mod constants for the unpacked logical shape.
    let in_shape = shape
        .iter()
        .map(|s| FastDivmodArgs::<usize>::new(client, *s))
        .collect();

    copy_kernel_packed::launch(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(1),
        output.as_tensor_arg(line_size),
        out_layout,
        in_shape,
        in_packed_dim,
        packing,
        in_rank,
        elems_per_unit,
        dtype,
    )
}
451
/// Checks if the tensor associated with the given shape and strides is contiguous.
///
/// A tensor is contiguous when its strides are exactly the row-major
/// ("compact") strides derived from its shape.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    // Zero-rank tensors are trivially contiguous.
    if shape.is_empty() {
        return true;
    }

    // Build the compact (row-major) strides for this shape: innermost is 1,
    // each outer stride is the product of all inner extents.
    let rank = shape.len();
    let mut expected = vec![1usize; rank];
    for dim in (0..rank - 1).rev() {
        expected[dim] = expected[dim + 1] * shape[dim + 1];
    }

    // Contiguous iff every actual stride matches the compact one.
    expected.iter().zip(strides).all(|(&want, &got)| want == got)
}
466
/// Checks if a tensor is only strided on the last dimension, and could be safely reinterpreted as
/// a 2D tensor with unit stride on the last dimension. This will always hold for non-permuted
/// tensors allocated on a runtime.
///
/// Accepts when, in order:
/// - the innermost stride is 1;
/// - strides are in non-increasing order (no permutation);
/// - every adjacent dimension pair is compact, except the pair
///   `(rank - 2, rank - 1)`, whose stride may exceed the row length (the pitch).
pub fn is_contiguous_pitched(shape: &[usize], strides: &[usize]) -> bool {
    let rank = shape.len();

    // Rank 0 is trivially contiguous; checking first also avoids the
    // `rank - 1` underflow the previous version hit on empty shapes.
    if rank == 0 {
        return true;
    }
    if strides[rank - 1] != 1 {
        return false;
    }
    if rank == 1 {
        return true;
    }

    // Strides must be non-increasing, i.e. the tensor is not permuted.
    let mut sorted = strides.to_vec();
    sorted.sort_unstable_by(|a, b| b.cmp(a));

    if sorted != strides {
        return false;
    }

    // Every adjacent pair except the last must be compact; the last pair is
    // exempt so rows may be pitched (padded) by the allocator.
    for i in 0..rank - 2 {
        if strides[i] != shape[i + 1] * strides[i + 1] {
            return false;
        }
    }
    true
}
494
/// Computes the compact (row-major, fully contiguous) strides for `shape`.
///
/// The innermost dimension gets stride 1 and each outer dimension's stride is
/// the product of all inner extents. Returns an empty vector for rank 0.
pub fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps an empty shape from underflowing `rank - 1`
    // (a debug-build panic in the previous version).
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}