cubecl_matmul/components/global/read/layout/
tiled.rs

1use cubecl::prelude::*;
2use cubecl_core as cubecl;
3use cubecl_std::tensor::layout::{Coords2d, Layout, LayoutExpand};
4
5use crate::components::{MatrixLayout, global::memory::GlobalMemoryConfig};
6
7pub type TiledCoords = (Coords2d, u32);
8
/// Tiling mapping on a 2D layout.
///
/// Positions are given as [`TiledCoords`]: a `(tile_row, tile_col)` tile index plus a
/// linear offset inside that tile. The offset is split into a 2D offset within the tile
/// according to the configured [`MatrixLayout`], then added to the tile's origin
/// (`tile_index * tile_size` on each axis) to produce the source position.
#[derive(CubeType)]
pub struct TiledLayout {
    /// Comptime configuration supplying the tile and stage sizes and the in-tile matrix layout.
    #[cube(comptime)]
    config: GlobalMemoryConfig,
}
15
16#[cube]
17impl TiledLayout {
18    pub fn new(#[comptime] config: GlobalMemoryConfig) -> Self {
19        TiledLayout { config }
20    }
21}
22
23#[cube]
24impl Layout for TiledLayout {
25    type Coordinates = TiledCoords;
26    type SourceCoordinates = Coords2d;
27
28    fn to_source_pos(&self, pos: Self::Coordinates) -> Self::SourceCoordinates {
29        let (tile, unit_pos) = pos;
30        let (tile_row, tile_col) = tile;
31
32        let tile_size_row = comptime![self.config.elements_in_tile_row];
33        let tile_size_col = comptime![self.config.elements_in_tile_col];
34
35        let view_tile_row = tile_row * tile_size_row;
36        let view_tile_col = tile_col * tile_size_col;
37
38        let (unit_row, unit_col) = match comptime![self.config.matrix_layout] {
39            MatrixLayout::RowMajor => (unit_pos / tile_size_col, unit_pos % tile_size_col),
40            MatrixLayout::ColMajor => (unit_pos % tile_size_row, unit_pos / tile_size_row),
41        };
42
43        (view_tile_row + unit_row, view_tile_col + unit_col)
44    }
45
46    fn shape(&self) -> Self::Coordinates {
47        let tile_size_row = comptime![self.config.elements_in_tile_row];
48        let tile_size_col = comptime![self.config.elements_in_tile_col];
49
50        let tiles_row = comptime![self.config.elements_in_stage_row / tile_size_row];
51        let tiles_col = comptime![self.config.elements_in_stage_col / tile_size_col];
52        let tile_size = tile_size_row * tile_size_col;
53        ((tiles_row, tiles_col), tile_size).runtime()
54    }
55
56    fn is_in_bounds(&self, _pos: Self::Coordinates) -> bool {
57        // Bounds checking should be handled by underlying layout
58        true.runtime()
59    }
60
61    fn to_source_pos_checked(&self, pos: Self::Coordinates) -> (Self::SourceCoordinates, bool) {
62        (self.to_source_pos(pos), self.is_in_bounds(pos))
63    }
64}