cubecl_matmul/components/stage/
layout.rs1use std::marker::PhantomData;
2
3use cubecl_core::prelude::*;
4use cubecl_core::{self as cubecl};
5
6use crate::components::tile::Tile;
7use crate::components::{Ident, InputIdent, MatrixLayout};
8
9use super::{StageConfig, StageMemory};
10
11#[cube]
12pub trait TilingOrder: 'static + Send + Sync + Clone + Copy {
15 fn to_row_col<C: StageConfig>(
17 nth: u32,
18 #[comptime] tile_count_rows: u32,
19 #[comptime] tile_count_cols: u32,
20 #[comptime] ident: Ident,
21 #[comptime] config: C,
22 ) -> (u32, u32);
23
24 fn to_nth_tile<C: StageConfig>(
27 row: u32,
28 col: u32,
29 #[comptime] tile_count_rows: u32,
30 #[comptime] tile_count_cols: u32,
31 #[comptime] ident: Ident,
32 #[comptime] config: C,
33 ) -> u32;
34
35 fn to_enum() -> comptime_type!(TilingOrderEnum);
37}
38
/// Comptime tag identifying a [`TilingOrder`] implementation, as returned by
/// [`TilingOrder::to_enum`].
pub enum TilingOrderEnum {
    /// Tiles are numbered row after row: `nth = row * tile_count_cols + col`.
    RowMajor,
    /// Tiles are numbered column after column: `nth = col * tile_count_rows + row`.
    ColMajor,
    /// Tile rows are split into `num_main_flow_planes` equal groups.
    ///
    /// Groups are stored one after the other; inside a group, tiles are numbered column
    /// after column. Only valid for the Lhs operand.
    Ordered,
    /// Presumably the order used together with TMA loads (inferred from the name only).
    ///
    /// No implementor in this module returns this variant. NOTE(review): confirm its
    /// semantics where it is produced.
    Tma,
}
52
53#[derive(CubeType, Clone, Copy)]
54pub struct RowMajorTilingOrder {}
72
73#[derive(CubeType, Clone, Copy)]
74pub struct ColMajorTilingOrder {}
93
94#[derive(CubeType, Clone, Copy)]
95pub struct OrderedTilingOrder {}
120
121#[cube]
122impl TilingOrder for RowMajorTilingOrder {
123 fn to_row_col<C: StageConfig>(
124 nth: u32,
125 #[comptime] _tile_count_rows: u32,
126 #[comptime] tile_count_cols: u32,
127 #[comptime] _ident: Ident,
128 #[comptime] _config: C,
129 ) -> (u32, u32) {
130 (nth / tile_count_cols, nth % tile_count_cols)
131 }
132 fn to_nth_tile<C: StageConfig>(
133 row: u32,
134 col: u32,
135 #[comptime] _tile_count_rows: u32,
136 #[comptime] tile_count_cols: u32,
137 #[comptime] _ident: Ident,
138 #[comptime] _config: C,
139 ) -> u32 {
140 row * tile_count_cols + col
141 }
142
143 fn to_enum() -> comptime_type!(TilingOrderEnum) {
144 TilingOrderEnum::RowMajor
145 }
146}
147
148#[cube]
149impl TilingOrder for ColMajorTilingOrder {
150 fn to_row_col<C: StageConfig>(
151 nth: u32,
152 #[comptime] num_rows: u32,
153 #[comptime] _num_cols: u32,
154 #[comptime] _ident: Ident,
155 #[comptime] _config: C,
156 ) -> (u32, u32) {
157 (nth % num_rows, nth / num_rows)
158 }
159 fn to_nth_tile<C: StageConfig>(
160 row: u32,
161 col: u32,
162 #[comptime] tile_count_rows: u32,
163 #[comptime] _tile_count_cols: u32,
164 #[comptime] _ident: Ident,
165 #[comptime] _config: C,
166 ) -> u32 {
167 col * tile_count_rows + row
168 }
169
170 fn to_enum() -> comptime_type!(TilingOrderEnum) {
171 TilingOrderEnum::ColMajor
172 }
173}
174
175#[cube]
176impl TilingOrder for OrderedTilingOrder {
177 fn to_row_col<C: StageConfig>(
178 nth: u32,
179 #[comptime] tile_count_rows: u32,
180 #[comptime] tile_count_cols: u32,
181 #[comptime] ident: Ident,
182 #[comptime] config: C,
183 ) -> (u32, u32) {
184 if Ident::Lhs != ident {
185 panic!("Ordered tiling order should be used only on Lhs")
186 }
187
188 let group_rows = tile_count_rows / config.num_main_flow_planes();
189 let tiles_per_group = group_rows * tile_count_cols;
190
191 let group = nth / tiles_per_group;
192 let pos_within_group = nth % tiles_per_group;
193
194 let local_row = pos_within_group % group_rows;
195 let row = group * group_rows + local_row;
196 let col = pos_within_group / group_rows;
197
198 (row, col)
199 }
200
201 fn to_nth_tile<C: StageConfig>(
202 row: u32,
203 col: u32,
204 #[comptime] tile_count_rows: u32,
205 #[comptime] tile_count_cols: u32,
206 #[comptime] ident: Ident,
207 #[comptime] config: C,
208 ) -> u32 {
209 if Ident::Lhs != ident {
210 panic!("Ordered tiling order should be used only on Lhs")
211 }
212
213 let group_rows = tile_count_rows / config.num_main_flow_planes();
214 let group = row / group_rows;
215
216 let local_row = row % group_rows;
217 let tiles_per_group = group_rows * tile_count_cols;
218 let pos_within_group = col * group_rows + local_row;
219
220 group * tiles_per_group + pos_within_group
221 }
222
223 fn to_enum() -> comptime_type!(TilingOrderEnum) {
224 TilingOrderEnum::Ordered
225 }
226}
227
228#[cube]
229pub trait TilingLayout: 'static + Send + Sync + Clone + Copy {
231 fn get_tile<ES: Numeric, S: StageConfig>(
233 stage: &StageMemory<ES, Self>,
234 row: u32,
235 col: u32,
236 #[comptime] buffer_index: u32,
237 #[comptime] ident: Ident,
238 #[comptime] config: S,
239 ) -> Tile<ES>;
240}
241
242#[derive(Clone, Copy)]
243pub struct ContiguousTilingLayout<T: TilingOrder> {
246 tiling_order: PhantomData<T>,
247}
248
/// Stage layout where the whole stage is stored as a single matrix, in the matrix layout
/// given by the config.
///
/// A tile is therefore a strided window into stage memory.
#[derive(Clone, Copy)]
pub struct StridedTilingLayout {}
253
254#[cube]
255impl<T: TilingOrder> ContiguousTilingLayout<T> {
256 pub fn to_x_y<S: StageConfig>(
258 nth: u32,
259 #[comptime] ident: Ident,
260 #[comptime] config: S,
261 ) -> (u32, u32) {
262 let num_x = config.tiling_scheme().tiles_in_stage_row(ident);
263 let num_y = config.tiling_scheme().tiles_in_stage_col(ident);
264
265 T::to_row_col::<S>(nth, num_x, num_y, ident, config)
266 }
267}
268
269#[cube]
270impl<TO: TilingOrder> TilingLayout for ContiguousTilingLayout<TO> {
271 fn get_tile<ES: Numeric, S: StageConfig>(
272 stage_memory: &StageMemory<ES, Self>,
273 row: u32,
274 col: u32,
275 #[comptime] buffer_index: u32,
276 #[comptime] ident: Ident,
277 #[comptime] config: S,
278 ) -> Tile<ES> {
279 let stage_line_size = config.stage_line_size(ident);
280 let tiling_scheme = config.tiling_scheme();
281 let matrix_layout = config.matrix_layout(ident);
282
283 let (row_buffer_offset, col_buffer_offset, total_tile_count_row, total_tile_count_col) =
284 match ident.as_input_ident() {
285 InputIdent::Lhs => {
286 let x_tile_offset = 0;
287 let y_tile_offset = tiling_scheme.tiles_in_stage_col(ident) * buffer_index;
288 let total_tile_count_x = tiling_scheme.tiles_in_stage_row(ident);
289 let total_tile_count_y = tiling_scheme.tiles_in_stage_col(ident)
290 * config.num_stages(InputIdent::Lhs);
291 (
292 x_tile_offset,
293 y_tile_offset,
294 total_tile_count_x,
295 total_tile_count_y,
296 )
297 }
298 InputIdent::Rhs => {
299 let x_tile_offset = tiling_scheme.tiles_in_stage_row(ident) * buffer_index;
300 let y_tile_offset = 0;
301 let total_tile_count_x = tiling_scheme.tiles_in_stage_row(ident)
302 * config.num_stages(InputIdent::Rhs);
303 let total_tile_count_y = tiling_scheme.tiles_in_stage_col(ident);
304 (
305 x_tile_offset,
306 y_tile_offset,
307 total_tile_count_x,
308 total_tile_count_y,
309 )
310 }
311 };
312
313 let (tile_size_x, tile_size_y, tile_slice_length) = match matrix_layout {
314 MatrixLayout::RowMajor => {
315 let tile_size_x = tiling_scheme.elements_in_tile_row(ident);
316 let tile_size_y = tiling_scheme.elements_in_tile_col(ident) / stage_line_size;
317 let stride_x = comptime!(tile_size_y * total_tile_count_col);
318 let length = (tile_size_x - 1) * stride_x + tile_size_y;
319
320 (tile_size_x, tile_size_y, length)
321 }
322 MatrixLayout::ColMajor => {
323 let tile_size_x = tiling_scheme.elements_in_tile_row(ident) / stage_line_size;
324 let tile_size_y = tiling_scheme.elements_in_tile_col(ident);
325 let stride_y = comptime!(tile_size_x * total_tile_count_row);
326 let length = (tile_size_y - 1) * stride_y + tile_size_x;
327
328 (tile_size_x, tile_size_y, length)
329 }
330 };
331
332 let start = tile_size_x
333 * tile_size_y
334 * TO::to_nth_tile::<S>(
335 row + row_buffer_offset,
336 col + col_buffer_offset,
337 total_tile_count_row,
338 total_tile_count_col,
339 ident,
340 config,
341 );
342
343 Tile::new_contiguous::<S::TileConfig>(
344 stage_memory
345 .as_slice(stage_line_size)
346 .slice(start, start + tile_slice_length),
347 ident,
348 config.tile_config(),
349 )
350 }
351}
352
353#[cube]
354impl StridedTilingLayout {
355 pub fn nth_slice<ES: Numeric, S: StageConfig>(
357 stage: &mut StageMemory<ES, Self>,
358 nth: u32,
359 #[comptime] ident: Ident,
360 #[comptime] config: S,
361 ) -> SliceMut<Line<ES>> {
362 let matrix_layout = config.matrix_layout(ident);
363 let stage_line_size = config.stage_line_size(ident);
364
365 let slice_length = match comptime!(matrix_layout) {
366 MatrixLayout::RowMajor => config.tiling_scheme().elements_in_stage_col(ident),
367 MatrixLayout::ColMajor => config.tiling_scheme().elements_in_stage_row(ident),
368 } / stage_line_size;
369
370 let start = slice_length * nth;
371 stage
372 .as_slice_mut(stage_line_size)
373 .slice_mut(start, start + slice_length)
374 }
375}
376
377#[cube]
378impl TilingLayout for StridedTilingLayout {
379 fn get_tile<ES: Numeric, S: StageConfig>(
380 stage: &StageMemory<ES, Self>,
381 x: u32,
382 y: u32,
383 #[comptime] _buffer_index: u32,
384 #[comptime] ident: Ident,
385 #[comptime] config: S,
386 ) -> Tile<ES> {
387 if comptime!(config.num_stages(ident.as_input_ident()) > 1) {
388 unimplemented!()
389 }
390
391 let stage_line_size = config.stage_line_size(ident);
392 let tiling_scheme = config.tiling_scheme();
393 let matrix_layout = config.matrix_layout(ident);
394
395 let tile_count_x = tiling_scheme.tiles_in_stage_row(ident);
396 let tile_count_y = tiling_scheme.tiles_in_stage_col(ident);
397
398 match matrix_layout {
399 MatrixLayout::RowMajor => {
400 let tile_size_x = tiling_scheme.elements_in_tile_row(ident);
401 let tile_size_y = tiling_scheme.elements_in_tile_col(ident) / stage_line_size;
402
403 let stride = tile_count_y * tile_size_y;
404 let length = (tile_size_x - 1) * stride + tile_size_y;
405 let start = x * tile_size_x * stride + y * tile_size_y;
406
407 Tile::new_strided(
408 stage.as_slice(stage_line_size).slice(start, start + length),
409 stride,
410 matrix_layout,
411 )
412 }
413 MatrixLayout::ColMajor => {
414 let tile_size_x = tiling_scheme.elements_in_tile_row(ident) / stage_line_size;
415 let tile_size_y = tiling_scheme.elements_in_tile_col(ident);
416
417 let stride = tile_count_x * tile_size_x;
418 let length = (tile_size_y - 1) * stride + tile_size_x;
419 let start = x * tile_size_x + y * tile_size_y * stride;
420
421 Tile::new_strided(
422 stage.as_slice(stage_line_size).slice(start, start + length),
423 stride,
424 matrix_layout,
425 )
426 }
427 }
428 }
429}