use super::TensorHandle;
use cubecl::prelude::*;
use cubecl_core::{self as cubecl, calculate_cube_count_elemwise, tensor_line_size_parallel};
use cubecl_std::{FastDivmod, FastDivmodArgs};

#[cube]
/// Returns the line-index into `tensor` of the element that `offset_layout`
/// (also a line-index) addresses in the reference tensor `layout`, considering
/// only the dimensions in `dim_start..dim_end`.
pub fn index_offset_with_layout<N: CubePrimitive, L: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    layout: &Tensor<Line<L>>,
    offset_layout: u32,
    dim_start: u32,
    dim_end: u32,
    #[comptime] unroll: bool,
) -> u32 {
    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;

    #[unroll(unroll)]
    for i in dim_start..dim_end {
        // Coordinate in dimension `i` of the reference layout, remapped
        // through `tensor`'s own shape and strides.
        let ogwl = offset_ref / layout.stride(i);
        offset += ogwl % tensor.shape(i) * tensor.stride(i);
    }

    offset / tensor.line_size()
}
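
// Worked example (a sketch with made-up values): reading a transposed 2x3
// view (`tensor.strides == [1, 2]`) in the order of a contiguous layout
// (`layout.strides == [3, 1]`), with line size 1 and `offset_layout = 4`:
//   i = 0: 4 / 3 = 1, and 1 % 2 * 1 = 1
//   i = 1: 4 / 1 = 4, and 4 % 3 * 2 = 2
// The result, 3, is the storage offset of logical element (1, 1).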

#[cube]
/// Returns the line-index into `tensor` of the element that `offset_layout`
/// (a line-index) addresses in contiguous (row-major) order. Passing the rank
/// as a comptime value allows the loop to be fully unrolled.
pub fn index_offset_contiguous<N: CubePrimitive>(
    tensor: &Tensor<Line<N>>,
    offset_layout: u32,
    #[comptime] rank: Option<u32>,
) -> u32 {
    let unroll = rank.is_some();
    let rank = rank.unwrap_or_else(|| tensor.rank());

    let offset_ref = offset_layout * tensor.line_size();
    let mut offset = 0;
    let mut remainder = offset_ref;

    // Peel coordinates off the innermost dimension first, applying each
    // dimension's stride as it is recovered.
    #[unroll(unroll)]
    for i in 0..rank {
        let dim = rank - i - 1;
        let shape = tensor.shape(dim);
        let ogwl = remainder % shape;
        offset += ogwl * tensor.stride(dim);
        remainder /= shape;
    }

    offset / tensor.line_size()
}
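
// Worked example (a sketch with made-up values): for a 4x4 transposed view
// (`tensor.strides == [1, 4]`), line size 1, and `offset_layout = 6`:
//   dim 1: 6 % 4 = 2, so offset += 2 * 4, remainder = 1
//   dim 0: 1 % 4 = 1, so offset += 1 * 1
// The result, 9, is the storage offset of logical element (1, 2).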

#[derive(CubeType, CubeLaunch)]
/// Memory layout of the kernel's output, which is either pitched (rows padded
/// to an alignment boundary) or fully contiguous.
pub enum StridedLayout {
    /// Rows are padded; wraps a fast divmod by the logical row width.
    Pitched(FastDivmod),
    /// Fully contiguous; indices pass through unchanged.
    None,
}

impl<R: Runtime> StridedLayoutArgs<'_, R> {
    /// Launch argument for a fully contiguous output.
    pub fn none() -> Self {
        Self::None
    }

    /// Launch argument for a pitched output, where `shape` is the logical
    /// size of the last dimension in elements.
    pub fn strided(client: &ComputeClient<R::Server, R::Channel>, shape: u32) -> Self {
        Self::Pitched(FastDivmodArgs::new(client, shape))
    }
}

#[cube]
impl StridedLayout {
    /// Remaps a linear output index to a physical index, accounting for row
    /// padding when the layout is pitched.
    pub fn index<T: CubePrimitive>(&self, tensor: &Tensor<Line<T>>, index: u32) -> u32 {
        match self {
            StridedLayout::Pitched(divmod) => {
                // Recover (row, col) from the absolute element offset, then
                // re-apply the padded row stride.
                let offset_abs = index * tensor.line_size();
                let (y, x) = divmod.div_mod(offset_abs);
                let offset = y * tensor.stride(tensor.rank() - 2) + x;
                offset / tensor.line_size()
            }
            StridedLayout::None => index,
        }
    }
}
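
// Worked example (a sketch with made-up values): with a logical row width of
// 100 elements, a padded row stride of 128, and line size 4, `index = 30`
// gives `offset_abs = 120`, which divmod splits into (y, x) = (1, 20); the
// physical line-index is then (1 * 128 + 20) / 4 = 37.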

#[cube(launch)]
/// Copies `input` into `output` in contiguous order, each unit handling
/// `elems_per_thread` consecutive output lines.
fn into_contiguous_kernel<N: CubePrimitive>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Line<N>>,
    out_layout: StridedLayout,
    #[comptime] rank: Option<u32>,
    #[comptime] elems_per_thread: u32,
) {
    let offset_output = ABSOLUTE_POS * elems_per_thread;
    let line_size = input.line_size();

    // Stage the lines in registers so all strided reads are issued before the
    // contiguous writes (the "prefetch" in `into_contiguous_prefetch`).
    let mut registers = Array::vectorized(elems_per_thread, line_size);

    #[unroll]
    for i in 0..elems_per_thread {
        let offset_input = index_offset_contiguous::<N>(input, offset_output + i, rank);
        registers[i] = input[offset_input];
    }

    let offset_output = out_layout.index(output, offset_output);

    #[unroll]
    for i in 0..elems_per_thread {
        output[offset_output + i] = registers[i];
    }
}

/// Makes `input` contiguous, returning a densely packed output tensor.
pub fn into_contiguous<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    let num_elems: usize = input.shape.iter().product();
    let rank = input.strides.len();
    // Vectorization is only enabled when the last dimension is contiguous.
    let vectorization_factor = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );
    let num_vecs = num_elems / vectorization_factor as usize;
    // Occupancy heuristic: assume roughly 64 SMs and copy more lines per unit
    // as the workload grows past what a single wave can cover.
    let approx_sm = 64;
    let approx_simul_vecs = approx_sm * CubeDim::default().num_elems();
    let elems_per_unit = match num_vecs as u32 / approx_simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    into_contiguous_prefetch(client, input, elems_per_unit, false)
}
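
// Example usage (a sketch: `client`, `handle`, and the `CudaRuntime` type are
// assumed to come from elsewhere, e.g. cubecl-cuda; note `from_raw_parts`
// takes strides before shape):
//
//     let input = unsafe {
//         TensorHandleRef::<CudaRuntime>::from_raw_parts(
//             &handle, &[1, 4], &[4, 4], size_of::<f32>(),
//         )
//     };
//     let output = into_contiguous::<CudaRuntime, f32>(&client, &input);
//     assert_eq!(output.strides, vec![4, 1]);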

/// Makes `input` contiguous, but lets the runtime pad the rows (pitched
/// allocation) for better memory alignment. Falls back to [`into_contiguous`]
/// for tensors of rank <= 1.
pub fn into_contiguous_pitched<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
) -> TensorHandle<R, E> {
    if input.shape.len() <= 1 {
        return into_contiguous(client, input);
    }

    let num_elems: usize = input.shape.iter().product();
    let rank = input.strides.len();
    // Vectorization is only enabled when the last dimension is contiguous.
    let vectorization_factor = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );
    let num_vecs = num_elems / vectorization_factor as usize;
    let approx_sm = 64;
    let approx_simul_vecs = approx_sm * CubeDim::default().num_elems();
    let elems_per_unit = match num_vecs as u32 / approx_simul_vecs {
        0..2 => 1,
        2..4 => 2,
        4..8 => 4,
        8.. => 8,
    };

    into_contiguous_prefetch(client, input, elems_per_unit, true)
}
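
// Example usage (a sketch, reusing the names from the example above): for an
// f32 tensor of shape [3, 100], the pitched output keeps `shape == [3, 100]`
// but its row stride may be padded, depending on the runtime's pitch
// alignment:
//
//     let output = into_contiguous_pitched::<CudaRuntime, f32>(&client, &input);
//     assert!(output.strides[0] >= 100);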

/// Makes `input` contiguous, with each unit copying `elems_per_unit` lines.
/// When `pitched` is true, the output rows may be padded by the allocator.
pub fn into_contiguous_prefetch<R: Runtime, E: CubePrimitive>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    mut elems_per_unit: u32,
    pitched: bool,
) -> TensorHandle<R, E> {
    let rank = input.strides.len();
    // Vectorization is only enabled when the last dimension is contiguous.
    let vectorization_factor = tensor_line_size_parallel(
        R::supported_line_sizes().iter().cloned(),
        input.shape,
        input.strides,
        rank - 1,
    );

    let num_elems: usize = input.shape.iter().product();
    let output = if pitched {
        TensorHandle::empty(client, input.shape.to_vec())
    } else {
        let handle = client.empty(num_elems * size_of::<E>());
        TensorHandle::new_contiguous(input.shape.to_vec(), handle)
    };

    let mut num_elems_per_unit = vectorization_factor as u32 * elems_per_unit;

    let last_dim = output.shape[rank - 1];
    let is_padded = rank > 1 && last_dim != output.strides[rank - 2];

    // Shrink the per-unit workload until it evenly divides the row width, so
    // no unit straddles the padding at the end of a row.
    while is_padded && last_dim % num_elems_per_unit as usize != 0 {
        elems_per_unit /= 2;
        num_elems_per_unit /= 2;
    }
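
    // For example (a sketch with made-up values): with a row width of 100
    // elements, line size 4, and `elems_per_unit = 8`, the per-unit workload
    // shrinks 32 -> 16 -> 8 -> 4 elements, stopping at 4 since 100 % 4 == 0.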

    let out_layout = match is_padded {
        true => StridedLayoutArgs::strided(client, last_dim as u32),
        false => StridedLayoutArgs::none(),
    };

    let cube_dim = CubeDim::default();
    let cube_count =
        calculate_cube_count_elemwise(num_elems.div_ceil(num_elems_per_unit as usize), cube_dim);

    into_contiguous_kernel::launch::<Line<E>, R>(
        client,
        cube_count,
        cube_dim,
        input.as_tensor_arg(vectorization_factor),
        output.as_ref().as_tensor_arg(vectorization_factor),
        out_layout,
        Some(rank as u32),
        elems_per_unit,
    );

    output
}