cubecl_linalg/tensor/
contiguous.rs

1use super::TensorHandle;
2use cubecl::prelude::*;
3use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel};
4use cubecl_std::{FastDivmod, FastDivmodArgs};
5
/// Returns the offset of the tensor corresponding to the layout tensor.
///
/// `offset_layout` is an index in *line* units relative to `layout`; the
/// result is the matching index in *line* units into `tensor`. Only the
/// dimensions in `dim_start..dim_end` contribute to the translation.
///
/// * `unroll` - comptime flag to unroll the dimension loop (requires the
///   range to be known at compile time).
#[cube]
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: u32,
    dim_start: u32,
    dim_end: u32,
    #[comptime] unroll: bool,
) -> u32 {
    // Work in scalar (unvectorized) element units.
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        // Recover the coordinate along dimension `i` from the layout's
        // strides, then re-project it through the tensor's shape/strides.
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    // Convert the scalar offset back to line units.
    offset / tensor.line_size()
}
27
/// Returns the offset of the tensor corresponding to a contiguous layout.
///
/// `offset_layout` is the index (in line units) a fully contiguous tensor of
/// the same shape would use; the result is the matching line index into the
/// possibly-strided `tensor`.
///
/// * `rank` - when `Some`, the rank is a comptime constant and the loop is
///   unrolled; when `None`, the rank is read from the tensor at runtime.
#[cube]
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    #[comptime] rank: Option<u32>,
) -> u32 {
    // Unrolling is only possible when the rank is known at compile time.
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    // Work in scalar (unvectorized) element units.
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    #[unroll(unroll)]
    for i in 0..rank {
        // Peel coordinates from the innermost (last) dimension outward.
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    // Convert the scalar offset back to line units.
    offset / tensor.line_size()
}
53
/// Layout for tensor that may or may not be strided on the last dimension. Efficiently translates
/// the absolute index to strided index.
#[derive(CubeType, CubeLaunch)]
pub enum StridedLayout {
    /// Last dimension is pitched: holds a fast divisor for the last-dimension
    /// shape, used to split an absolute offset into (row, in-row) parts.
    Pitched(FastDivmod),
    /// Last dimension is contiguous; indices pass through unchanged.
    None,
}
61
// `StridedLayoutArgs` is the launch-side companion presumably generated by
// `#[derive(CubeLaunch)]` on `StridedLayout` above — confirm against the macro docs.
impl<R: Runtime> StridedLayoutArgs<'_, R> {
    /// Last dimension is contiguous in second last dimension
    pub fn none() -> Self {
        Self::None
    }

    /// Last dimension is strided with the last dimension having the shape `shape`
    pub fn strided(client: &ComputeClient<R::Server, R::Channel>, shape: u32) -> Self {
        // Precompute the fast-division constants for `shape` on the client.
        Self::Pitched(FastDivmodArgs::new(client, shape))
    }
}
73
#[cube]
impl StridedLayout {
    /// Translates absolute index to strided index if applicable.
    ///
    /// `index` is in line units; for the `Pitched` variant it is re-based onto
    /// the pitched row stride of `tensor`, otherwise it is returned unchanged.
    pub fn index<T: CubePrimitive>(&self, tensor: &Tensor<Line<T>>, index: u32) -> u32 {
        match self {
            StridedLayout::Pitched(divmod) => {
                // Scalar element offset as if the tensor were contiguous.
                let offset_abs = index * tensor.line_size();
                // Split into (row `y`, in-row position `x`) via the last-dim shape.
                let (y, x) = divmod.div_mod(offset_abs);
                // Re-apply the pitched row stride (stride of the second-to-last dim).
                let offset = y * tensor.stride(tensor.rank() - 2) + x;
                // Back to line units.
                offset / tensor.line_size()
            }
            StridedLayout::None => index,
        }
    }
}
89
/// Copies a strided `input` into `output` in contiguous order.
///
/// Each unit handles `elems_per_thread` consecutive output lines: it first
/// gathers them into a register array (resolving the strided input offset per
/// line), then writes them out, translating the output index through
/// `out_layout` once for the base offset.
#[cube(launch)]
fn into_contiguous_kernel<N: CubePrimitive>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: StridedLayout,
    #[comptime] rank: Option<u32>,
    #[comptime] elems_per_thread: u32,
) {
    let offset_output = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    // Stage all reads before any write so loads and stores are grouped.
    let mut registers = Array::vectorized(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        let offset_input = index_offset_contiguous::<N>(input, offset_output + i, rank);
        registers[i] = input[offset_input];
    }

    // The layout translation is applied to the base offset only; the `+ i`
    // writes below assume the `elems_per_thread` lines never cross a pitched
    // row boundary (the launcher shrinks `elems_per_thread` to guarantee this).
    let offset_output = out_layout.index(output, offset_output);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}
116
117/// Make a jit tensor contiguous.
118pub fn into_contiguous<R: Runtime, E: CubePrimitive>(
119    client: &ComputeClient<R::Server, R::Channel>,
120    input: &TensorHandleRef<'_, R>,
121) -> TensorHandle<R, E> {
122    let num_elems: usize = input.shape.iter().product();
123    // Vectorization is only enabled when the last dimension is contiguous.
124    let rank = input.strides.len();
125    let vectorization_factor = tensor_line_size_parallel(
126        R::supported_line_sizes().iter().cloned(),
127        input.shape,
128        input.strides,
129        rank - 1,
130    );
131    let num_vecs = num_elems / vectorization_factor as usize;
132    let approx_sm = 64;
133    let approx_simul_vecs = approx_sm * CubeDim::default().num_elems();
134    let elems_per_unit = match num_vecs as u32 / approx_simul_vecs {
135        0..2 => 1,
136        2..4 => 2,
137        4..8 => 4,
138        8.. => 8,
139    };
140
141    // TODO: Benchmark to find good default prefetch, for now preserve existing behaviour
142    into_contiguous_prefetch(client, input, elems_per_unit, false)
143}
144
145/// Make a jit tensor contiguous, using the pitched allocator if available.
146/// See [create_tensor](cubecl_runtime::client::ComputeClient::create_tensor).
147pub fn into_contiguous_pitched<R: Runtime, E: CubePrimitive>(
148    client: &ComputeClient<R::Server, R::Channel>,
149    input: &TensorHandleRef<'_, R>,
150) -> TensorHandle<R, E> {
151    if input.shape.len() <= 1 {
152        return into_contiguous(client, input);
153    }
154
155    let num_elems: usize = input.shape.iter().product();
156    // Vectorization is only enabled when the last dimension is contiguous.
157    let rank = input.strides.len();
158    let vectorization_factor = tensor_line_size_parallel(
159        R::supported_line_sizes().iter().cloned(),
160        input.shape,
161        input.strides,
162        rank - 1,
163    );
164    let num_vecs = num_elems / vectorization_factor as usize;
165    let approx_sm = 64;
166    let approx_simul_vecs = approx_sm * CubeDim::default().num_elems();
167    let elems_per_unit = match num_vecs as u32 / approx_simul_vecs {
168        0..2 => 1,
169        2..4 => 2,
170        4..8 => 4,
171        8.. => 8,
172    };
173
174    // TODO: Benchmark to find good default prefetch, for now preserve existing behaviour
175    into_contiguous_prefetch(client, input, elems_per_unit, true)
176}
177
/// Make a jit tensor contiguous.
///
/// * `elems_per_unit` - number of output lines each unit copies; may be
///   reduced internally so a unit never straddles a pitched row boundary.
/// * `pitched` - when true, allocates the output through `TensorHandle::empty`
///   (which may produce a padded/pitched layout); otherwise allocates a plain
///   contiguous buffer.
pub fn into_contiguous_prefetch<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    mut elems_per_unit: u32,
    pitched: bool,
) -> TensorHandle<R, E> {
    // Vectorization is only enabled when the last dimension is contiguous.
    let rank = input.strides.len();
    let vectorization_factor = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );

    let num_elems: usize = input.shape.iter().product();
    let output = if pitched {
        TensorHandle::empty(client, input.shape.to_vec())
    } else {
        let handle = client.empty(num_elems * size_of::<E>());
        TensorHandle::new_contiguous(input.shape.to_vec(), handle)
    };

    // Total scalar elements each unit processes (lines * line width).
    let mut num_elems_per_unit = vectorization_factor as u32 * elems_per_unit;

    // The output is padded if the second-to-last stride exceeds the last dim.
    let last_dim = output.shape[rank - 1];
    let is_padded = rank > 1 && last_dim != output.strides[rank - 2];

    // If tensor is strided, elems_per_unit must be compatible with last dim
    // NOTE(review): assumes `vectorization_factor` divides `last_dim` (expected
    // from `tensor_line_size_parallel`); otherwise this loop would drive
    // `elems_per_unit` to 0 — confirm.
    while is_padded && last_dim % num_elems_per_unit as usize != 0 {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }

    // Only pay for the divmod-based index translation when actually padded.
    let out_layout = match is_padded {
        true => StridedLayoutArgs::strided(client, last_dim as u32),
        false => StridedLayoutArgs::none(),
    };

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    into_contiguous_kernel::launch::<Line<E>, R>(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(vectorization_factor),
        output.as_ref().as_tensor_arg(vectorization_factor),
        out_layout,
        Some(rank as u32),
        elems_per_unit,
    );

    output
}