cubecl_std/tensor/contiguous.rs

use crate::{FastDivmod, FastDivmodArgs};

use super::TensorHandle;
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel};

/// Approximate number of streaming multiprocessors, used as a fallback when the hardware
/// properties don't report the actual count.
pub const NUM_SM_APPROX: u32 = 50;

/// Returns the offset into `tensor` of the element located at linear position `offset_layout`
/// in the reference `layout` tensor, translating between the two memory layouts over the
/// dimensions `dim_start..dim_end`.
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: u32,
    dim_start: u32,
    dim_end: u32,
    #[comptime] unroll: bool,
) -> u32 {
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        // Recover the coordinate along dim `i` from the layout's stride, wrap it by `tensor`'s
        // shape (which also handles broadcast dimensions), then re-project it onto `tensor`'s
        // stride.
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.line_size()
}
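
// Illustrative worked example: with a reference `layout` of shape [2, 3] and contiguous strides
// [3, 1], and `tensor` broadcast along dim 0 with shape [1, 3] and strides [3, 1], the linear
// position 4 corresponds to layout coords [1, 1]; dim 0 wraps to 0 (shape 1) and dim 1 stays 1,
// so the returned offset is 0 * 3 + 1 * 1 = 1 (assuming line_size = 1).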

/// Returns the offset into `tensor` of the element at linear position `offset_layout` in an
/// equivalent contiguous (row-major) layout.
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    #[comptime] rank: Option<u32>,
) -> u32 {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        // Decompose the contiguous index into coordinates from the last dimension to the first,
        // projecting each coordinate onto the tensor's actual strides.
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.line_size()
}
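
// Illustrative worked example: for a tensor of shape [2, 3, 4] whose last dimension is padded so
// that the strides are [24, 8, 1], the contiguous element index 13 decomposes into coords
// [1, 0, 1], which maps to the strided offset 1 * 24 + 0 * 8 + 1 * 1 = 25 (assuming
// line_size = 1).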

/// Returns the offset of the tensor corresponding to a contiguous layout, using precomputed
/// [`FastDivmod`] divisors for the shape and an explicit stride sequence.
#[cube]
pub fn index_offset_contiguous_fastdivmod<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    shape: &Sequence<FastDivmod>,
    stride: &Sequence<u32>,
) -> u32 {
    let rank = comptime![shape.len()];

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    // Walk the dimensions from last to first; `dim` is decremented at compile time so the loop
    // fully unrolls with constant indices.
    let mut dim = comptime![rank - 1];

    #[unroll]
    for _ in 0..rank {
        let shape = shape.index(dim);
        let (rem, ogwl) = shape.div_mod(remainder);
        offset += ogwl * stride.index(dim);
        remainder = rem;

        comptime![dim = dim.saturating_sub(1);]
    }

    offset / tensor.line_size()
}
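
// Host-side construction sketch (mirrors `into_contiguous_ref` below; `client` and `input` are
// hypothetical bindings assumed to be in scope): one `FastDivmod` per shape dimension and one
// scalar per stride.
//
//     let shape = SequenceArg {
//         values: input.shape.iter().map(|d| FastDivmodArgs::new(client, *d as u32)).collect(),
//     };
//     let stride = SequenceArg {
//         values: input.strides.iter().map(|s| ScalarArg::new(*s as u32)).collect(),
//     };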

/// Layout for a tensor that may or may not be strided (pitched) on the last dimension.
/// Efficiently translates an absolute index into a strided index.
#[derive(CubeType, CubeLaunch)]
pub enum StridedLayout {
    Pitched(FastDivmod),
    None,
}

impl<R: Runtime> StridedLayoutArgs<'_, R> {
    /// The last dimension is contiguous within the second-to-last dimension (no padding).
    pub fn none() -> Self {
        Self::None
    }

    /// The last dimension is strided (padded), with `shape` being its logical size.
    pub fn strided(client: &ComputeClient<R::Server, R::Channel>, shape: u32) -> Self {
        Self::Pitched(FastDivmodArgs::new(client, shape))
    }
}

#[cube]
impl StridedLayout {
    /// Translates an absolute index into a strided index, if applicable.
    pub fn index<T: CubePrimitive>(&self, tensor: &Tensor<Line<T>>, index: u32) -> u32 {
        match self {
            StridedLayout::Pitched(divmod) => {
                // Split the absolute element offset into (row, column) using the logical
                // last-dim shape, then re-apply the padded row pitch (`stride(rank - 2)`).
                let offset_abs = index * tensor.line_size();
                let (y, x) = divmod.div_mod(offset_abs);
                let offset = y * tensor.stride(tensor.rank() - 2) + x;
                offset / tensor.line_size()
            }
            StridedLayout::None => index,
        }
    }
}
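
// Illustrative worked example: for an output whose logical last dimension is 6 but whose rows
// are padded to a pitch of 8 elements (`stride(rank - 2) == 8`), the absolute index 13 splits
// into row 2, column 1 and maps to 2 * 8 + 1 = 17 (assuming line_size = 1).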

/// Copies `input` into the contiguous (possibly pitched) `output`, with each unit handling
/// `elems_per_thread` lines.
#[cube(launch)]
fn into_contiguous_kernel<N: CubePrimitive>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: StridedLayout,
    shape: Sequence<FastDivmod>,
    stride: Sequence<u32>,
    #[comptime] elems_per_thread: u32,
) {
    let offset_output = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    let mut registers = Array::<Line<N>>::vectorized(elems_per_thread, line_size);

    // Gather the (potentially strided) input lines into registers first, then write them out
    // contiguously.
    #[unroll]
    for i in 0..elems_per_thread {
        let offset_input =
            index_offset_contiguous_fastdivmod::<N>(input, offset_output + i, &shape, &stride);
        registers[i] = input[offset_input];
    }

    let offset_output = out_layout.index(output, offset_output);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Variant of [`into_contiguous_kernel`] used when the input can't be read in vectorized lines:
/// scalar elements are gathered one by one and packed into vectorized output lines.
#[cube(launch)]
fn into_contiguous_kernel_pack<N: CubePrimitive>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: StridedLayout,
    shape: Sequence<FastDivmod>,
    stride: Sequence<u32>,
    #[comptime] elems_per_thread: u32,
) {
    let line_size = output.line_size();
    let lines_per_thread = comptime![elems_per_thread / line_size];

    let offset_output = ABSOLUTE_POS * lines_per_thread;
    let offset_input = offset_output * line_size;

    let mut registers = Array::<Line<N>>::vectorized(lines_per_thread, line_size);

    #[unroll]
    for i in 0..lines_per_thread {
        let offset = i * line_size;
        let mut reg = Line::<N>::empty(line_size);
        #[unroll]
        for k in 0..line_size {
            let offset_input = offset_input + offset + k;
            let offset_input =
                index_offset_contiguous_fastdivmod::<N>(input, offset_input, &shape, &stride);
            reg[k] = input[offset_input][0];
        }
        registers[i] = reg;
    }

    let offset_output = out_layout.index(output, offset_output);

    #[unroll]
    for i in 0..lines_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Makes a JIT tensor contiguous.
pub fn into_contiguous<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    let num_elems: usize = input.shape.iter().product();

    let handle = client.empty(num_elems * size_of::<E>());
    let output = TensorHandle::new_contiguous(input.shape.to_vec(), handle);

    into_contiguous_ref::<R, E>(client, input, &output.as_ref());

    output
}
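
// Illustrative usage sketch (`client` and `input` are hypothetical bindings: a compute client
// and a possibly-strided `TensorHandleRef` for some runtime `R`):
//
//     let output: TensorHandle<R, f32> = into_contiguous::<R, f32>(&client, &input);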

/// Makes a JIT tensor contiguous, using the pitched allocator if available.
/// See [create_tensor](cubecl_runtime::client::ComputeClient::create_tensor).
pub fn into_contiguous_pitched<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    if input.shape.len() <= 1 {
        return into_contiguous(client, input);
    }

    let output = TensorHandle::empty(client, input.shape.to_vec());

    into_contiguous_ref::<R, E>(client, input, &output.as_ref());

    output
}
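
// Note: with a pitched allocation the returned handle's rows may be padded, i.e.
// `output.strides[rank - 2]` can exceed the logical last-dim shape; `into_contiguous_ref`
// detects this case and writes the output through the `StridedLayout::Pitched` path.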

/// Makes a JIT tensor contiguous, writing the result into the provided `output` handle.
pub fn into_contiguous_ref<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    output: &TensorHandleRef<'_, R>,
) {
    let num_elems: usize = input.shape.iter().product();

    // Vectorization is only enabled when the last dimension is contiguous.
    let rank = input.strides.len();
    let vectorization_factor = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );
    // Heuristic: give each unit more elements when there are many more vectors than can run
    // simultaneously across all SMs.
    let num_vecs = num_elems / vectorization_factor as usize;
    let num_sm = client
        .properties()
        .hardware
        .num_streaming_multiprocessors
        .unwrap_or(NUM_SM_APPROX);
    let simul_vecs = num_sm * CubeDim::default().num_elems();
    let mut elems_per_unit = match num_vecs as u32 / simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    let mut num_elems_per_unit = vectorization_factor as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];
    let is_padded = rank > 1 && last_dim != output.strides[rank - 2];

    // If the output is padded (pitched), `elems_per_unit` must evenly divide the last dim so a
    // unit never straddles a row boundary.
    while is_padded && last_dim % num_elems_per_unit as usize != 0 {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    let out_layout = match is_padded {
        true => StridedLayoutArgs::strided(client, last_dim as u32),
        false => StridedLayoutArgs::none(),
    };

    // Even if the input can't be read in vectorized lines, the output may still be written
    // vectorized by the pack kernel; pick the largest supported line size dividing the per-unit
    // element count.
    let out_vec = if vectorization_factor > 1 {
        vectorization_factor
    } else {
        *R::supported_line_sizes()
            .iter()
            .filter(|it| num_elems_per_unit % **it as u32 == 0)
            .max()
            .unwrap_or(&1)
    };

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    let shape = SequenceArg {
        values: input
            .shape
            .iter()
            .map(|dim| FastDivmodArgs::new(client, *dim as u32))
            .collect(),
    };

    let stride = SequenceArg {
        values: input
            .strides
            .iter()
            .map(|s| ScalarArg::new(*s as u32))
            .collect(),
    };

    let launch = if vectorization_factor != out_vec && out_vec > 1 {
        into_contiguous_kernel_pack::launch::<E, R>
    } else {
        into_contiguous_kernel::launch::<E, R>
    };

    launch(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(vectorization_factor),
        output.as_tensor_arg(out_vec),
        out_layout,
        shape,
        stride,
        elems_per_unit,
    );
}
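
// Hypothetical sizing example: for 2^24 elements read with a line size of 4 there are 2^22
// vectors; on a device reporting 50 SMs, and assuming `CubeDim::default()` covers 256 units,
// `simul_vecs` = 12_800 and `num_vecs / simul_vecs` ≈ 327, so each unit handles 8 lines and
// `num_elems_per_unit` = 32 scalar elements.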

/// Checks if the tensor associated with the given shape and strides is contiguous.
pub fn is_contiguous(shape: &[usize], strides: &[usize]) -> bool {
    if shape.is_empty() {
        return true;
    }

    for (expected, &stride) in compact_strides(shape).into_iter().zip(strides) {
        if expected != stride {
            return false;
        }
    }

    true
}

/// Computes the contiguous (row-major) strides for the given shape.
pub fn compact_strides(shape: &[usize]) -> Vec<usize> {
    let rank = shape.len();
    let mut strides = vec![1; rank];
    // `saturating_sub` keeps the range empty for rank 0 instead of underflowing.
    for i in (0..rank.saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    strides
}
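
// Illustrative test sketch (hypothetical module) exercising the two host-side helpers above with
// hand-computed row-major strides.
#[cfg(test)]
mod contiguous_helper_tests {
    use super::{compact_strides, is_contiguous};

    #[test]
    fn compact_strides_are_row_major() {
        // shape [2, 3, 4] -> strides [3 * 4, 4, 1]
        assert_eq!(compact_strides(&[2, 3, 4]), vec![12, 4, 1]);
        assert_eq!(compact_strides(&[5]), vec![1]);
    }

    #[test]
    fn detects_contiguous_and_padded_layouts() {
        assert!(is_contiguous(&[2, 3], &[3, 1]));
        // Row-padded (pitched) layout: last dim has shape 3 but a row pitch of 4.
        assert!(!is_contiguous(&[2, 3], &[4, 1]));
    }
}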