cubecl_matmul/components/
tiling_scheme.rs

1use super::Ident;
2use super::size::{GlobalPartitionSize, MatmulDim, PartitionSize, StageSize, TileSize};
3
4#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
5/// Complete tiling configuration for a matmul.
6/// Encodes all structural information needed to compute tiling shapes and counts.
7pub struct TilingScheme {
8    pub tile_size: TileSize,
9    pub partition_size: PartitionSize,
10    pub stage_size: StageSize,
11    pub global_partition_size: GlobalPartitionSize,
12}
13
14impl TilingScheme {
15    /// Create a builder for TilingScheme
16    pub fn builder() -> TilingSchemeBuilder {
17        TilingSchemeBuilder::default()
18    }
19}
20
21#[derive(Debug, Default)]
22/// Builder for [`TilingScheme`]. Allows step-by-step configuration.
23pub struct TilingSchemeBuilder {
24    tile_size: Option<TileSize>,
25    partition_size: Option<PartitionSize>,
26    stage_size: Option<StageSize>,
27    global_partition_size: Option<GlobalPartitionSize>,
28}
29
30impl TilingSchemeBuilder {
31    /// Specify tile size for tiling scheme
32    pub fn with_tile_size(mut self, tile_size: TileSize) -> Self {
33        self.tile_size = Some(tile_size);
34        self
35    }
36
37    /// Specify partition size for tiling scheme
38    pub fn with_partition_size(mut self, partition_size: PartitionSize) -> Self {
39        self.partition_size = Some(partition_size);
40        self
41    }
42
43    /// Specify stage size for tiling scheme
44    ///
45    /// Only stage size k = 1 is supported
46    pub fn with_stage_size(mut self, stage_size: StageSize) -> Self {
47        assert!(stage_size.k == 1, "Stage size k > 1 is not supported");
48        self.stage_size = Some(stage_size);
49        self
50    }
51
52    /// Optional: specify global partition size for tiling scheme
53    ///
54    /// If not specified, will default to (1, 1, 1)
55    pub fn with_global_partition_size(
56        mut self,
57        global_partition_size: GlobalPartitionSize,
58    ) -> Self {
59        self.global_partition_size = Some(global_partition_size);
60        self
61    }
62
63    /// Finish building
64    pub fn build(self) -> Result<TilingScheme, &'static str> {
65        Ok(TilingScheme {
66            tile_size: self.tile_size.ok_or("Missing tile_size")?,
67            partition_size: self.partition_size.ok_or("Missing tiles_per_partition")?,
68            stage_size: self.stage_size.ok_or("Missing partitions_per_stage")?,
69            global_partition_size: self
70                .global_partition_size
71                .unwrap_or(GlobalPartitionSize::new(1, 1, 1)),
72        })
73    }
74}
75
/// Levels of the matmul tiling hierarchy, listed from coarsest to finest.
///
/// Each level directly contains the next one: a global partition holds stages, a stage
/// holds stage partitions, a stage partition holds tiles, and a tile holds elements.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TilingLevel {
    /// Group of stages (only sized along `M` and `N`).
    GlobalPartition,
    /// Group of stage partitions.
    Stage,
    /// Group of tiles within a stage.
    StagePartition,
    /// Group of elements.
    Tile,
    /// Finest level: a single element.
    Element,
}
84
85impl TilingScheme {
86    fn try_count_1d(
87        &self,
88        child_level: TilingLevel,
89        parent_level: TilingLevel,
90        dim: MatmulDim,
91    ) -> Option<u32> {
92        use TilingLevel::*;
93
94        match (child_level, parent_level) {
95            (child, parent) if child == parent => Some(1),
96
97            (Stage, GlobalPartition) => match dim {
98                MatmulDim::M => Some(self.global_partition_size.m),
99                MatmulDim::N => Some(self.global_partition_size.n),
100                MatmulDim::K => None,
101            },
102
103            (StagePartition, Stage) => Some(self.stage_size.get(dim)),
104
105            (Tile, StagePartition) => Some(self.partition_size.get(dim)),
106
107            (Element, Tile) => Some(self.tile_size.get(dim)),
108
109            (StagePartition, GlobalPartition) => Some(
110                self.try_count_1d(StagePartition, Stage, dim)?
111                    * self.try_count_1d(Stage, GlobalPartition, dim)?,
112            ),
113
114            (Tile, GlobalPartition) => Some(
115                self.try_count_1d(Tile, Stage, dim)?
116                    * self.try_count_1d(Stage, GlobalPartition, dim)?,
117            ),
118
119            (Element, GlobalPartition) => Some(
120                self.try_count_1d(Element, Stage, dim)?
121                    * self.try_count_1d(Stage, GlobalPartition, dim)?,
122            ),
123
124            (Tile, Stage) => Some(
125                self.try_count_1d(StagePartition, Stage, dim)?
126                    * self.try_count_1d(Tile, StagePartition, dim)?,
127            ),
128
129            (Element, Stage) => {
130                Some(self.try_count_1d(Tile, Stage, dim)? * self.try_count_1d(Element, Tile, dim)?)
131            }
132
133            (Element, StagePartition) => Some(
134                self.try_count_1d(Tile, StagePartition, dim)?
135                    * self.try_count_1d(Element, Tile, dim)?,
136            ),
137
138            // Invalid transitions
139            _ => None,
140        }
141    }
142
143    fn try_count_2d(
144        &self,
145        child_level: TilingLevel,
146        parent_level: TilingLevel,
147        dim1: MatmulDim,
148        dim2: MatmulDim,
149    ) -> Option<u32> {
150        Some(
151            self.try_count_1d(child_level, parent_level, dim1)?
152                * self.try_count_1d(child_level, parent_level, dim2)?,
153        )
154    }
155
156    fn count_1d(&self, child_level: TilingLevel, parent_level: TilingLevel, dim: MatmulDim) -> u32 {
157        self.try_count_1d(child_level, parent_level, dim)
158            .unwrap_or_else(|| {
159                panic!("Invalid hierarchy: {parent_level:?} cannot contain {child_level:?}")
160            })
161    }
162
163    fn count_1d_ident_row<I: Into<Ident>>(
164        &self,
165        child_level: TilingLevel,
166        parent_level: TilingLevel,
167        ident: I,
168    ) -> u32 {
169        match ident.into() {
170            Ident::Lhs => self.count_1d(child_level, parent_level, MatmulDim::M),
171            Ident::Rhs => self.count_1d(child_level, parent_level, MatmulDim::K),
172            Ident::Out => self.count_1d(child_level, parent_level, MatmulDim::M),
173        }
174    }
175
176    fn count_1d_ident_col<I: Into<Ident>>(
177        &self,
178        child_level: TilingLevel,
179        parent_level: TilingLevel,
180        ident: I,
181    ) -> u32 {
182        match ident.into() {
183            Ident::Lhs => self.count_1d(child_level, parent_level, MatmulDim::K),
184            Ident::Rhs => self.count_1d(child_level, parent_level, MatmulDim::N),
185            Ident::Out => self.count_1d(child_level, parent_level, MatmulDim::N),
186        }
187    }
188
189    fn count_2d(
190        &self,
191        child_level: TilingLevel,
192        parent_level: TilingLevel,
193        dim1: MatmulDim,
194        dim2: MatmulDim,
195    ) -> u32 {
196        self.try_count_2d(child_level, parent_level, dim1, dim2)
197            .unwrap_or_else(|| {
198                panic!("Invalid hierarchy: {parent_level:?} cannot contain {child_level:?}")
199            })
200    }
201
202    fn count_2d_ident<I: Into<Ident>>(
203        &self,
204        child_level: TilingLevel,
205        parent_level: TilingLevel,
206        ident: I,
207    ) -> u32 {
208        match ident.into() {
209            Ident::Lhs => self.count_2d(child_level, parent_level, MatmulDim::M, MatmulDim::K),
210            Ident::Rhs => self.count_2d(child_level, parent_level, MatmulDim::K, MatmulDim::N),
211            Ident::Out => self.count_2d(child_level, parent_level, MatmulDim::M, MatmulDim::N),
212        }
213    }
214}
215
/// Generates `pub fn $name(&self) -> u32` for use inside `impl TilingScheme`.
///
/// The method returns how many `$child` units fit along `$dim` in one `$parent` unit.
/// It forwards to `count_1d`, which panics on an invalid hierarchy. The generated
/// method gets a rustdoc summary built from the macro arguments.
macro_rules! count_1d_method {
    ($name:ident, $child:ident, $parent:ident, $dim:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child), "` along `", stringify!($dim),
            "` in one `", stringify!($parent), "`."
        )]
        pub fn $name(&self) -> u32 {
            self.count_1d(TilingLevel::$child, TilingLevel::$parent, MatmulDim::$dim)
        }
    };
}
223
/// Generates `pub fn $name<I: Into<Ident>>(&self, ident: I) -> u32` for use inside
/// `impl TilingScheme`.
///
/// The method counts `$child` units along the row dimension of `ident`'s operand in one
/// `$parent` unit. It forwards to `count_1d_ident_row`, which maps `Lhs`/`Out` to `M` and
/// `Rhs` to `K`. The generated method gets a rustdoc summary built from the macro arguments.
macro_rules! count_1d_ident_row_method {
    ($name:ident, $child:ident, $parent:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child),
            "` along the row dimension of `ident` (`M` for `Lhs`/`Out`, `K` for `Rhs`) in one `",
            stringify!($parent), "`."
        )]
        pub fn $name<I: Into<Ident>>(&self, ident: I) -> u32 {
            self.count_1d_ident_row(TilingLevel::$child, TilingLevel::$parent, ident)
        }
    };
}
231
/// Generates `pub fn $name<I: Into<Ident>>(&self, ident: I) -> u32` for use inside
/// `impl TilingScheme`.
///
/// The method counts `$child` units along the column dimension of `ident`'s operand in one
/// `$parent` unit. It forwards to `count_1d_ident_col`, which maps `Lhs` to `K` and
/// `Rhs`/`Out` to `N`. The generated method gets a rustdoc summary built from the macro
/// arguments.
macro_rules! count_1d_ident_col_method {
    ($name:ident, $child:ident, $parent:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child),
            "` along the column dimension of `ident` (`K` for `Lhs`, `N` for `Rhs`/`Out`) in one `",
            stringify!($parent), "`."
        )]
        pub fn $name<I: Into<Ident>>(&self, ident: I) -> u32 {
            self.count_1d_ident_col(TilingLevel::$child, TilingLevel::$parent, ident)
        }
    };
}
239
/// Generates `pub fn $name(&self) -> u32` for use inside `impl TilingScheme`.
///
/// The method returns how many `$child` units fit in one `$parent` unit across `$dim1`
/// and `$dim2`, i.e. the product of both 1D counts. It forwards to `count_2d`, which
/// panics on an invalid hierarchy. The generated method gets a rustdoc summary built from
/// the macro arguments.
macro_rules! count_2d_method {
    ($name:ident, $child:ident, $parent:ident, $dim1:ident, $dim2:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child), "` in one `", stringify!($parent),
            "` across `", stringify!($dim1), "` and `", stringify!($dim2),
            "` (product of both 1D counts)."
        )]
        pub fn $name(&self) -> u32 {
            self.count_2d(
                TilingLevel::$child,
                TilingLevel::$parent,
                MatmulDim::$dim1,
                MatmulDim::$dim2,
            )
        }
    };
}
252
/// Generates `pub fn $name<I: Into<Ident>>(&self, ident: I) -> u32` for use inside
/// `impl TilingScheme`.
///
/// The method counts `$child` units in one `$parent` unit over the two dimensions of
/// `ident`'s operand. It forwards to `count_2d_ident`: `Lhs` = `M`x`K`, `Rhs` = `K`x`N`,
/// `Out` = `M`x`N`. The generated method gets a rustdoc summary built from the macro
/// arguments.
macro_rules! count_2d_ident_method {
    ($name:ident, $child:ident, $parent:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child), "` in one `", stringify!($parent),
            "` over the two dimensions of `ident` (`Lhs`: `M`x`K`, `Rhs`: `K`x`N`, `Out`: `M`x`N`)."
        )]
        pub fn $name<I: Into<Ident>>(&self, ident: I) -> u32 {
            self.count_2d_ident(TilingLevel::$child, TilingLevel::$parent, ident)
        }
    };
}
260
261impl TilingScheme {
262    count_1d_method!(stage_partitions_in_stage_m, StagePartition, Stage, M);
263    count_1d_method!(stage_partitions_in_stage_n, StagePartition, Stage, N);
264    count_1d_method!(stage_partitions_in_stage_k, StagePartition, Stage, K);
265    count_1d_ident_row_method!(stage_partitions_in_stage_row, StagePartition, Stage);
266    count_1d_ident_col_method!(stage_partitions_in_stage_col, StagePartition, Stage);
267    count_2d_method!(stage_partitions_in_stage_mk, StagePartition, Stage, M, K);
268    count_2d_method!(stage_partitions_in_stage_nk, StagePartition, Stage, N, K);
269    count_2d_method!(stage_partitions_in_stage_mn, StagePartition, Stage, M, N);
270    count_2d_ident_method!(stage_partitions_in_stage, StagePartition, Stage);
271
272    count_1d_method!(tiles_in_stage_m, Tile, Stage, M);
273    count_1d_method!(tiles_in_stage_n, Tile, Stage, N);
274    count_1d_method!(tiles_in_stage_k, Tile, Stage, K);
275    count_1d_ident_row_method!(tiles_in_stage_row, Tile, Stage);
276    count_1d_ident_col_method!(tiles_in_stage_col, Tile, Stage);
277    count_2d_method!(tiles_in_stage_mk, Tile, Stage, M, K);
278    count_2d_method!(tiles_in_stage_nk, Tile, Stage, N, K);
279    count_2d_method!(tiles_in_stage_mn, Tile, Stage, M, N);
280    count_2d_ident_method!(tiles_in_stage, Tile, Stage);
281
282    count_1d_method!(elements_in_stage_m, Element, Stage, M);
283    count_1d_method!(elements_in_stage_n, Element, Stage, N);
284    count_1d_method!(elements_in_stage_k, Element, Stage, K);
285    count_1d_ident_row_method!(elements_in_stage_row, Element, Stage);
286    count_1d_ident_col_method!(elements_in_stage_col, Element, Stage);
287    count_2d_method!(elements_in_stage_mk, Element, Stage, M, K);
288    count_2d_method!(elements_in_stage_nk, Element, Stage, N, K);
289    count_2d_method!(elements_in_stage_mn, Element, Stage, M, N);
290    count_2d_ident_method!(elements_in_stage, Element, Stage);
291
292    count_1d_method!(tiles_in_stage_partition_m, Tile, StagePartition, M);
293    count_1d_method!(tiles_in_stage_partition_n, Tile, StagePartition, N);
294    count_1d_method!(tiles_in_stage_partition_k, Tile, StagePartition, K);
295    count_1d_ident_row_method!(tiles_in_stage_partition_row, Tile, StagePartition);
296    count_1d_ident_col_method!(tiles_in_stage_partition_col, Tile, StagePartition);
297    count_2d_method!(tiles_in_stage_partition_mk, Tile, StagePartition, M, K);
298    count_2d_method!(tiles_in_stage_partition_nk, Tile, StagePartition, N, K);
299    count_2d_method!(tiles_in_stage_partition_mn, Tile, StagePartition, M, N);
300    count_2d_ident_method!(tiles_in_stage_partition, Tile, StagePartition);
301
302    count_1d_method!(elements_in_stage_partition_m, Element, StagePartition, M);
303    count_1d_method!(elements_in_stage_partition_n, Element, StagePartition, N);
304    count_1d_method!(elements_in_stage_partition_k, Element, StagePartition, K);
305    count_1d_ident_row_method!(elements_in_stage_partition_row, Element, StagePartition);
306    count_1d_ident_col_method!(elements_in_stage_partition_col, Element, StagePartition);
307    count_2d_method!(
308        elements_in_stage_partition_mk,
309        Element,
310        StagePartition,
311        M,
312        K
313    );
314    count_2d_method!(
315        elements_in_stage_partition_nk,
316        Element,
317        StagePartition,
318        N,
319        K
320    );
321    count_2d_method!(
322        elements_in_stage_partition_mn,
323        Element,
324        StagePartition,
325        M,
326        N
327    );
328    count_2d_ident_method!(elements_in_stage_partition, Element, StagePartition);
329
330    count_1d_method!(elements_in_tile_m, Element, Tile, M);
331    count_1d_method!(elements_in_tile_n, Element, Tile, N);
332    count_1d_method!(elements_in_tile_k, Element, Tile, K);
333    count_1d_ident_row_method!(elements_in_tile_row, Element, Tile);
334    count_1d_ident_col_method!(elements_in_tile_col, Element, Tile);
335    count_2d_method!(elements_in_tile_mk, Element, Tile, M, K);
336    count_2d_method!(elements_in_tile_nk, Element, Tile, N, K);
337    count_2d_method!(elements_in_tile_mn, Element, Tile, M, N);
338    count_2d_ident_method!(elements_in_tile, Element, Tile);
339
340    count_1d_method!(elements_in_global_partition_m, Element, GlobalPartition, M);
341    count_1d_method!(elements_in_global_partition_n, Element, GlobalPartition, N);
342    count_1d_method!(tiles_in_global_partition_m, Tile, GlobalPartition, M);
343    count_1d_method!(tiles_in_global_partition_n, Tile, GlobalPartition, N);
344    count_1d_method!(
345        stage_partitions_in_global_partition_m,
346        StagePartition,
347        GlobalPartition,
348        M
349    );
350    count_1d_method!(
351        stage_partitions_in_global_partition_n,
352        StagePartition,
353        GlobalPartition,
354        N
355    );
356    count_1d_method!(stages_in_global_partition_m, Stage, GlobalPartition, M);
357    count_1d_method!(stages_in_global_partition_n, Stage, GlobalPartition, N);
358}