cubecl_linalg/matmul/kernels/tiling2d/
config.rs1use cubecl_core::{
2 self as cubecl, CubeDim,
3 prelude::{Init, Scope},
4};
5use cubecl_core::{CubeCount, CubeType};
6
7use super::base::TILE_SIZE;
8
/// Host-side tuning parameters for the tiling 2D matmul kernel.
///
/// These values drive the launch geometry computed by `tiling2d_cube_count` and
/// `tiling2d_cube_dim`, and are validated and narrowed to `u32` by
/// `CubeTiling2dConfig::new`.
#[derive(Clone, Debug)]
pub struct Tiling2dConfig {
    /// Block size along `m`. The second-to-last output dimension is divided by
    /// it (rounding up) to obtain the cube count along x.
    pub block_size_m: usize,
    /// Block size along `k`. `CubeTiling2dConfig::new` asserts that it is no
    /// larger than `block_size_m` and `block_size_n`.
    pub block_size_k: usize,
    /// Block size along `n`. The last output dimension is divided by it
    /// (rounding up) to obtain the cube count along y.
    pub block_size_n: usize,
    /// Tile size. `CubeTiling2dConfig::new` asserts that it divides every block
    /// size; `block_size / tile_size` gives the cube dimensions.
    pub tile_size: usize,
    /// Unrolling flag, forwarded to `CubeTiling2dConfig::unroll_compute`.
    pub unroll: bool,
}
23
24impl Default for Tiling2dConfig {
25 fn default() -> Self {
26 Self {
27 block_size_m: 64,
28 block_size_k: 32,
29 block_size_n: 64,
30 tile_size: TILE_SIZE,
31 unroll: false,
32 }
33 }
34}
35
36#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug, CubeType)]
37pub struct CubeTiling2dConfig {
39 pub block_size_m: u32,
41 pub block_size_k: u32,
43 pub block_size_n: u32,
45 pub unroll_compute: bool,
47 pub unroll_tile: bool,
49 pub check_m_bounds: bool,
51 pub check_k_bounds: bool,
53 pub check_n_bounds: bool,
55 pub tile_size: u32,
57 pub lhs_transposed: bool,
59 pub rhs_transposed: bool,
61}
62
63impl Init for CubeTiling2dConfig {
64 fn init(self, _scope: &mut Scope) -> Self {
65 self
66 }
67}
68
69impl CubeTiling2dConfig {
70 pub fn new(
71 config: &Tiling2dConfig,
72 m: usize,
73 k: usize,
74 n: usize,
75 lhs_transposed: bool,
76 rhs_transposed: bool,
77 ) -> Self {
78 assert!(
79 config.block_size_k <= config.block_size_m
80 && config.block_size_k <= config.block_size_n,
81 "Larger block size in k than m or n results in unfilled shared memory."
82 );
83 assert!(
84 config.block_size_m % config.tile_size == 0
85 && config.block_size_k % config.tile_size == 0
86 && config.block_size_n % config.tile_size == 0,
87 "Tiling 2d algorithm assumes tile size divides block size perfectly. "
88 );
89
90 CubeTiling2dConfig {
91 block_size_m: config.block_size_m as u32,
92 block_size_k: config.block_size_k as u32,
93 block_size_n: config.block_size_n as u32,
94 unroll_compute: config.unroll,
95 unroll_tile: true,
96 check_m_bounds: m % config.block_size_m != 0,
97 check_k_bounds: k % config.block_size_k != 0,
98 check_n_bounds: n % config.block_size_n != 0,
99 tile_size: config.tile_size as u32,
100 lhs_transposed,
101 rhs_transposed,
102 }
103 }
104}
105
106pub fn tiling2d_cube_count(output_shape: &[usize], config: &Tiling2dConfig) -> CubeCount {
107 let rank = output_shape.len();
108 let num_rows = *output_shape.get(rank - 2).unwrap();
109 let num_cols = *output_shape.get(rank - 1).unwrap();
110
111 let cubes_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
112 let cubes_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
113 let mut num_iter = 1;
114 for shape in output_shape.iter().take(rank - 2) {
115 num_iter *= shape;
116 }
117
118 CubeCount::Static(cubes_x, cubes_y, num_iter as u32)
119}
120
121pub fn tiling2d_cube_dim(config: &Tiling2dConfig) -> CubeDim {
122 CubeDim::new(
123 (config.block_size_m / config.tile_size) as u32,
124 (config.block_size_n / config.tile_size) as u32,
125 1,
126 )
127}