cubecl_matmul/components/global/read/layout/
tiled.rs1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::tensor::layout::{Coords2d, Layout, LayoutExpand};
4
5use crate::components::{MatrixLayout, stage::StageMemoryConfig};
6
7pub type TiledCoords = (Coords2d, u32);
8
9#[derive(CubeType)]
11pub struct TiledLayout {
12 #[cube(comptime)]
13 config: StageMemoryConfig,
14}
15
16#[cube]
17impl TiledLayout {
18 pub fn new(#[comptime] config: StageMemoryConfig) -> Self {
19 TiledLayout { config }
20 }
21}
22
23#[cube]
24impl Layout for TiledLayout {
25 type Coordinates = TiledCoords;
26 type SourceCoordinates = Coords2d;
27
28 fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
29 let (tile, unit_pos) = pos;
30 let (tile_row, tile_col) = tile;
31
32 let tile_size_row = comptime![self.config.elements_per_tile_along_row];
33 let tile_size_col = comptime![self.config.elements_per_tile_along_col];
34
35 let view_tile_row = tile_row * tile_size_row;
36 let view_tile_col = tile_col * tile_size_col;
37
38 let (unit_row, unit_col) = match comptime![self.config.matrix_layout] {
39 MatrixLayout::RowMajor => (unit_pos / tile_size_col, unit_pos % tile_size_col),
40 MatrixLayout::ColMajor => (unit_pos % tile_size_row, unit_pos / tile_size_row),
41 };
42
43 (view_tile_row + unit_row, view_tile_col + unit_col)
44 }
45
46 fn shape(&self) -> Self::Coordinates {
47 let tile_size_row = comptime![self.config.elements_per_tile_along_row];
48 let tile_size_col = comptime![self.config.elements_per_tile_along_col];
49
50 let tiles_row = comptime![self.config.elements_per_stage_along_row() / tile_size_row];
51 let tiles_col = comptime![self.config.elements_per_stage_along_col() / tile_size_col];
52 let tile_size = tile_size_row * tile_size_col;
53 ((tiles_row, tiles_col), tile_size).runtime()
54 }
55
56 fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
57 true.runtime()
59 }
60
61 fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
62 (self.to_source_pos(pos), self.is_in_bounds(pos))
63 }
64}