cubecl_linalg/matmul/kernels/tiling2d/
config.rs

1use cubecl_core::{
2    self as cubecl, CubeDim,
3    prelude::{Init, Scope},
4};
5use cubecl_core::{CubeCount, CubeType};
6
7use super::base::TILE_SIZE;
8
/// Host-side parameters of the tiling 2D matmul algorithm.
#[derive(Debug, Clone)]
pub struct Tiling2dConfig {
    /// Size of a block along the lhs dimension (`m`).
    pub block_size_m: usize,
    /// Size of a block along the dimension shared by lhs and rhs (`k`).
    pub block_size_k: usize,
    /// Size of a block along the rhs dimension (`n`).
    pub block_size_n: usize,
    /// Tile size, also used as the shared memory vectorization.
    pub tile_size: usize,
    /// Whether loops should be unrolled.
    pub unroll: bool,
}
23
24impl Default for Tiling2dConfig {
25    fn default() -> Self {
26        Self {
27            block_size_m: 64,
28            block_size_k: 32,
29            block_size_n: 64,
30            tile_size: TILE_SIZE,
31            unroll: false,
32        }
33    }
34}
35
36#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug, CubeType)]
37/// Tiling 2D parameters
38pub struct CubeTiling2dConfig {
39    /// Block size along dimension of lhs
40    pub block_size_m: u32,
41    /// Block size along common dimension
42    pub block_size_k: u32,
43    /// Block size along dimension of rhs
44    pub block_size_n: u32,
45    /// Loop unrolling for inner compute loop. Probably slower
46    pub unroll_compute: bool,
47    /// Loop unrolling for all loops related to vectorization/tile size. Probably faster
48    pub unroll_tile: bool,
49    /// Bounds must be checked on lhs dimension
50    pub check_m_bounds: bool,
51    /// Bounds must be checked on common dimension
52    pub check_k_bounds: bool,
53    /// Bounds must be checked on rhs dimension
54    pub check_n_bounds: bool,
55    /// Tile size. Should correspond to vectorization of inputs/outputs/shared memory
56    pub tile_size: u32,
57    /// Lhs is transposed in global memory
58    pub lhs_transposed: bool,
59    /// Rhs is transposed in global memory
60    pub rhs_transposed: bool,
61}
62
63impl Init for CubeTiling2dConfig {
64    fn init(self, _scope: &mut Scope) -> Self {
65        self
66    }
67}
68
69impl CubeTiling2dConfig {
70    pub fn new(
71        config: &Tiling2dConfig,
72        m: usize,
73        k: usize,
74        n: usize,
75        lhs_transposed: bool,
76        rhs_transposed: bool,
77    ) -> Self {
78        assert!(
79            config.block_size_k <= config.block_size_m
80                && config.block_size_k <= config.block_size_n,
81            "Larger block size in k than m or n results in unfilled shared memory."
82        );
83        assert!(
84            config.block_size_m % config.tile_size == 0
85                && config.block_size_k % config.tile_size == 0
86                && config.block_size_n % config.tile_size == 0,
87            "Tiling 2d algorithm assumes tile size divides block size perfectly. "
88        );
89
90        CubeTiling2dConfig {
91            block_size_m: config.block_size_m as u32,
92            block_size_k: config.block_size_k as u32,
93            block_size_n: config.block_size_n as u32,
94            unroll_compute: config.unroll,
95            unroll_tile: true,
96            check_m_bounds: m % config.block_size_m != 0,
97            check_k_bounds: k % config.block_size_k != 0,
98            check_n_bounds: n % config.block_size_n != 0,
99            tile_size: config.tile_size as u32,
100            lhs_transposed,
101            rhs_transposed,
102        }
103    }
104}
105
106pub fn tiling2d_cube_count(output_shape: &[usize], config: &Tiling2dConfig) -> CubeCount {
107    let rank = output_shape.len();
108    let num_rows = *output_shape.get(rank - 2).unwrap();
109    let num_cols = *output_shape.get(rank - 1).unwrap();
110
111    let cubes_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
112    let cubes_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
113    let mut num_iter = 1;
114    for shape in output_shape.iter().take(rank - 2) {
115        num_iter *= shape;
116    }
117
118    CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
119}
120
121pub fn tiling2d_cube_dim(config: &Tiling2dConfig) -> CubeDim {
122    CubeDim::new(
123        (config.block_size_m / config.tile_size) as u32,
124        (config.block_size_n / config.tile_size) as u32,
125        1,
126    )
127}