cubecl_matmul/components/
tiling_scheme.rs

1use super::StageIdent;
2use super::size::{GlobalPartitionSize, MatmulDim, PartitionSize, StageSize, TileSize};
3
4#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
5/// Complete tiling configuration for a matmul.
6/// Encodes all structural information needed to compute tiling shapes and counts.
7pub struct TilingScheme {
8    pub tile_size: TileSize,
9    pub partition_size: PartitionSize,
10    pub stage_size: StageSize,
11    pub global_partition_size: GlobalPartitionSize,
12}
13
14impl TilingScheme {
15    /// Create a builder for TilingScheme
16    pub fn builder() -> TilingSchemeBuilder {
17        TilingSchemeBuilder::default()
18    }
19}
20
21#[derive(Debug, Default)]
22/// Builder for [`TilingScheme`]. Allows step-by-step configuration.
23pub struct TilingSchemeBuilder {
24    tile_size: Option<TileSize>,
25    partition_size: Option<PartitionSize>,
26    stage_size: Option<StageSize>,
27    global_partition_size: Option<GlobalPartitionSize>,
28}
29
30impl TilingSchemeBuilder {
31    /// Specify tile size for tiling scheme
32    pub fn with_tile_size(mut self, tile_size: TileSize) -> Self {
33        self.tile_size = Some(tile_size);
34        self
35    }
36
37    /// Specify partition size for tiling scheme
38    pub fn with_partition_size(mut self, partition_size: PartitionSize) -> Self {
39        self.partition_size = Some(partition_size);
40        self
41    }
42
43    /// Specify stage size for tiling scheme
44    ///
45    /// Only stage size k = 1 is supported
46    pub fn with_stage_size(mut self, stage_size: StageSize) -> Self {
47        assert!(stage_size.k == 1, "Stage size k > 1 is not supported");
48        self.stage_size = Some(stage_size);
49        self
50    }
51
52    /// Optional: specify global partition size for tiling scheme
53    ///
54    /// If not specified, will default to (1, 1, 1)
55    pub fn with_global_partition_size(
56        mut self,
57        global_partition_size: GlobalPartitionSize,
58    ) -> Self {
59        self.global_partition_size = Some(global_partition_size);
60        self
61    }
62
63    /// Finish building
64    pub fn build(self) -> Result<TilingScheme, &'static str> {
65        Ok(TilingScheme {
66            tile_size: self.tile_size.ok_or("Missing tile_size")?,
67            partition_size: self.partition_size.ok_or("Missing tiles_per_partition")?,
68            stage_size: self.stage_size.ok_or("Missing partitions_per_stage")?,
69            global_partition_size: self
70                .global_partition_size
71                .unwrap_or(GlobalPartitionSize::new(1, 1, 1)),
72        })
73    }
74}
75
/// Levels of the tiling hierarchy, listed from coarsest to finest.
///
/// A global partition holds stages, a stage holds stage partitions, a stage partition
/// holds tiles, and a tile holds elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TilingLevel {
    /// Group of stages; only subdivided along `M` and `N`.
    GlobalPartition,
    /// One stage, split into stage partitions by `stage_size`.
    Stage,
    /// One stage partition, split into tiles by `partition_size`.
    StagePartition,
    /// One tile, split into elements by `tile_size`.
    Tile,
    /// A single element; the finest level.
    Element,
}
84
85impl TilingScheme {
86    fn try_count_1d(
87        &self,
88        child_level: TilingLevel,
89        parent_level: TilingLevel,
90        dim: MatmulDim,
91    ) -> Option<u32> {
92        use TilingLevel::*;
93
94        match (child_level, parent_level) {
95            (child, parent) if child == parent => Some(1),
96
97            (Stage, GlobalPartition) => match dim {
98                MatmulDim::M => Some(self.global_partition_size.m),
99                MatmulDim::N => Some(self.global_partition_size.n),
100                MatmulDim::K => None,
101            },
102
103            (StagePartition, Stage) => Some(self.stage_size.get(dim)),
104
105            (Tile, StagePartition) => Some(self.partition_size.get(dim)),
106
107            (Element, Tile) => Some(self.tile_size.get(dim)),
108
109            (StagePartition, GlobalPartition) => Some(
110                self.try_count_1d(StagePartition, Stage, dim)?
111                    * self.try_count_1d(Stage, GlobalPartition, dim)?,
112            ),
113
114            (Tile, GlobalPartition) => Some(
115                self.try_count_1d(Tile, Stage, dim)?
116                    * self.try_count_1d(Stage, GlobalPartition, dim)?,
117            ),
118
119            (Element, GlobalPartition) => Some(
120                self.try_count_1d(Element, Stage, dim)?
121                    * self.try_count_1d(Stage, GlobalPartition, dim)?,
122            ),
123
124            (Tile, Stage) => Some(
125                self.try_count_1d(StagePartition, Stage, dim)?
126                    * self.try_count_1d(Tile, StagePartition, dim)?,
127            ),
128
129            (Element, Stage) => {
130                Some(self.try_count_1d(Tile, Stage, dim)? * self.try_count_1d(Element, Tile, dim)?)
131            }
132
133            (Element, StagePartition) => Some(
134                self.try_count_1d(Tile, StagePartition, dim)?
135                    * self.try_count_1d(Element, Tile, dim)?,
136            ),
137
138            // Invalid transitions
139            _ => None,
140        }
141    }
142
143    fn try_count_2d(
144        &self,
145        child_level: TilingLevel,
146        parent_level: TilingLevel,
147        dim1: MatmulDim,
148        dim2: MatmulDim,
149    ) -> Option<u32> {
150        Some(
151            self.try_count_1d(child_level, parent_level, dim1)?
152                * self.try_count_1d(child_level, parent_level, dim2)?,
153        )
154    }
155
156    fn count_1d(&self, child_level: TilingLevel, parent_level: TilingLevel, dim: MatmulDim) -> u32 {
157        self.try_count_1d(child_level, parent_level, dim)
158            .unwrap_or_else(|| {
159                panic!("Invalid hierarchy: {parent_level:?} cannot contain {child_level:?}")
160            })
161    }
162
163    fn count_1d_ident_row(
164        &self,
165        child_level: TilingLevel,
166        parent_level: TilingLevel,
167        ident: StageIdent,
168    ) -> u32 {
169        match ident {
170            StageIdent::Lhs => self.count_1d(child_level, parent_level, MatmulDim::M),
171            StageIdent::Rhs => self.count_1d(child_level, parent_level, MatmulDim::K),
172            StageIdent::Acc => self.count_1d(child_level, parent_level, MatmulDim::M),
173            StageIdent::Out => self.count_1d(child_level, parent_level, MatmulDim::M),
174        }
175    }
176
177    fn count_1d_ident_col(
178        &self,
179        child_level: TilingLevel,
180        parent_level: TilingLevel,
181        ident: StageIdent,
182    ) -> u32 {
183        match ident {
184            StageIdent::Lhs => self.count_1d(child_level, parent_level, MatmulDim::K),
185            StageIdent::Rhs => self.count_1d(child_level, parent_level, MatmulDim::N),
186            StageIdent::Acc => self.count_1d(child_level, parent_level, MatmulDim::N),
187            StageIdent::Out => self.count_1d(child_level, parent_level, MatmulDim::N),
188        }
189    }
190
191    fn count_2d(
192        &self,
193        child_level: TilingLevel,
194        parent_level: TilingLevel,
195        dim1: MatmulDim,
196        dim2: MatmulDim,
197    ) -> u32 {
198        self.try_count_2d(child_level, parent_level, dim1, dim2)
199            .unwrap_or_else(|| {
200                panic!("Invalid hierarchy: {parent_level:?} cannot contain {child_level:?}")
201            })
202    }
203
204    fn count_2d_ident(
205        &self,
206        child_level: TilingLevel,
207        parent_level: TilingLevel,
208        ident: StageIdent,
209    ) -> u32 {
210        match ident {
211            StageIdent::Lhs => self.count_2d(child_level, parent_level, MatmulDim::M, MatmulDim::K),
212            StageIdent::Rhs => self.count_2d(child_level, parent_level, MatmulDim::K, MatmulDim::N),
213            StageIdent::Acc => self.count_2d(child_level, parent_level, MatmulDim::M, MatmulDim::N),
214            StageIdent::Out => self.count_2d(child_level, parent_level, MatmulDim::M, MatmulDim::N),
215        }
216    }
217}
218
/// Generates `pub fn $name(&self) -> u32`, the number of `$child` units along dimension
/// `$dim` in one `$parent` unit (via `count_1d`), with a generated doc comment.
macro_rules! count_1d_method {
    ($name:ident, $child:ident, $parent:ident, $dim:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child), "` units along dimension `",
            stringify!($dim), "` in one `", stringify!($parent), "`."
        )]
        pub fn $name(&self) -> u32 {
            self.count_1d(TilingLevel::$child, TilingLevel::$parent, MatmulDim::$dim)
        }
    };
}

/// Generates `pub fn $name(&self, ident)`, the number of `$child` units along the row
/// dimension of `ident` in one `$parent` unit (via `count_1d_ident_row`).
macro_rules! count_1d_ident_row_method {
    ($name:ident, $child:ident, $parent:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child), "` units along the row dimension of `ident` ",
            "in one `", stringify!($parent), "`.\n\n",
            "The row dimension is `M` for `Lhs`, `Acc` and `Out`, and `K` for `Rhs`."
        )]
        pub fn $name<I: Into<StageIdent>>(&self, ident: I) -> u32 {
            self.count_1d_ident_row(TilingLevel::$child, TilingLevel::$parent, ident.into())
        }
    };
}

/// Generates `pub fn $name(&self, ident)`, the number of `$child` units along the column
/// dimension of `ident` in one `$parent` unit (via `count_1d_ident_col`).
macro_rules! count_1d_ident_col_method {
    ($name:ident, $child:ident, $parent:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child), "` units along the column dimension of ",
            "`ident` in one `", stringify!($parent), "`.\n\n",
            "The column dimension is `K` for `Lhs`, and `N` for `Rhs`, `Acc` and `Out`."
        )]
        pub fn $name<I: Into<StageIdent>>(&self, ident: I) -> u32 {
            self.count_1d_ident_col(TilingLevel::$child, TilingLevel::$parent, ident.into())
        }
    };
}

/// Generates `pub fn $name(&self) -> u32`, the number of `$child` units covering the
/// `$dim1` by `$dim2` area of one `$parent` unit (via `count_2d`).
macro_rules! count_2d_method {
    ($name:ident, $child:ident, $parent:ident, $dim1:ident, $dim2:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child), "` units over dimensions `",
            stringify!($dim1), "` and `", stringify!($dim2), "` in one `",
            stringify!($parent), "`."
        )]
        pub fn $name(&self) -> u32 {
            self.count_2d(
                TilingLevel::$child,
                TilingLevel::$parent,
                MatmulDim::$dim1,
                MatmulDim::$dim2,
            )
        }
    };
}

/// Generates `pub fn $name(&self, ident)`, the number of `$child` units covering the 2D
/// shape of `ident` in one `$parent` unit (via `count_2d_ident`).
macro_rules! count_2d_ident_method {
    ($name:ident, $child:ident, $parent:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child), "` units covering the 2D shape of `ident` ",
            "in one `", stringify!($parent), "`.\n\n",
            "The shape is `M`x`K` for `Lhs`, `K`x`N` for `Rhs`, and `M`x`N` for `Acc` and `Out`."
        )]
        pub fn $name<I: Into<StageIdent>>(&self, ident: I) -> u32 {
            self.count_2d_ident(TilingLevel::$child, TilingLevel::$parent, ident.into())
        }
    };
}
263
264impl TilingScheme {
265    count_1d_method!(stage_partitions_in_stage_m, StagePartition, Stage, M);
266    count_1d_method!(stage_partitions_in_stage_n, StagePartition, Stage, N);
267    count_1d_method!(stage_partitions_in_stage_k, StagePartition, Stage, K);
268    count_1d_ident_row_method!(stage_partitions_in_stage_row, StagePartition, Stage);
269    count_1d_ident_col_method!(stage_partitions_in_stage_col, StagePartition, Stage);
270    count_2d_method!(stage_partitions_in_stage_mk, StagePartition, Stage, M, K);
271    count_2d_method!(stage_partitions_in_stage_nk, StagePartition, Stage, N, K);
272    count_2d_method!(stage_partitions_in_stage_mn, StagePartition, Stage, M, N);
273    count_2d_ident_method!(stage_partitions_in_stage, StagePartition, Stage);
274
275    count_1d_method!(tiles_in_stage_m, Tile, Stage, M);
276    count_1d_method!(tiles_in_stage_n, Tile, Stage, N);
277    count_1d_method!(tiles_in_stage_k, Tile, Stage, K);
278    count_1d_ident_row_method!(tiles_in_stage_row, Tile, Stage);
279    count_1d_ident_col_method!(tiles_in_stage_col, Tile, Stage);
280    count_2d_method!(tiles_in_stage_mk, Tile, Stage, M, K);
281    count_2d_method!(tiles_in_stage_nk, Tile, Stage, N, K);
282    count_2d_method!(tiles_in_stage_mn, Tile, Stage, M, N);
283    count_2d_ident_method!(tiles_in_stage, Tile, Stage);
284
285    count_1d_method!(elements_in_stage_m, Element, Stage, M);
286    count_1d_method!(elements_in_stage_n, Element, Stage, N);
287    count_1d_method!(elements_in_stage_k, Element, Stage, K);
288    count_1d_ident_row_method!(elements_in_stage_row, Element, Stage);
289    count_1d_ident_col_method!(elements_in_stage_col, Element, Stage);
290    count_2d_method!(elements_in_stage_mk, Element, Stage, M, K);
291    count_2d_method!(elements_in_stage_nk, Element, Stage, N, K);
292    count_2d_method!(elements_in_stage_mn, Element, Stage, M, N);
293    count_2d_ident_method!(elements_in_stage, Element, Stage);
294
295    count_1d_method!(tiles_in_stage_partition_m, Tile, StagePartition, M);
296    count_1d_method!(tiles_in_stage_partition_n, Tile, StagePartition, N);
297    count_1d_method!(tiles_in_stage_partition_k, Tile, StagePartition, K);
298    count_1d_ident_row_method!(tiles_in_stage_partition_row, Tile, StagePartition);
299    count_1d_ident_col_method!(tiles_in_stage_partition_col, Tile, StagePartition);
300    count_2d_method!(tiles_in_stage_partition_mk, Tile, StagePartition, M, K);
301    count_2d_method!(tiles_in_stage_partition_nk, Tile, StagePartition, N, K);
302    count_2d_method!(tiles_in_stage_partition_mn, Tile, StagePartition, M, N);
303    count_2d_ident_method!(tiles_in_stage_partition, Tile, StagePartition);
304
305    count_1d_method!(elements_in_stage_partition_m, Element, StagePartition, M);
306    count_1d_method!(elements_in_stage_partition_n, Element, StagePartition, N);
307    count_1d_method!(elements_in_stage_partition_k, Element, StagePartition, K);
308    count_1d_ident_row_method!(elements_in_stage_partition_row, Element, StagePartition);
309    count_1d_ident_col_method!(elements_in_stage_partition_col, Element, StagePartition);
310    count_2d_method!(
311        elements_in_stage_partition_mk,
312        Element,
313        StagePartition,
314        M,
315        K
316    );
317    count_2d_method!(
318        elements_in_stage_partition_nk,
319        Element,
320        StagePartition,
321        N,
322        K
323    );
324    count_2d_method!(
325        elements_in_stage_partition_mn,
326        Element,
327        StagePartition,
328        M,
329        N
330    );
331    count_2d_ident_method!(elements_in_stage_partition, Element, StagePartition);
332
333    count_1d_method!(elements_in_tile_m, Element, Tile, M);
334    count_1d_method!(elements_in_tile_n, Element, Tile, N);
335    count_1d_method!(elements_in_tile_k, Element, Tile, K);
336    count_1d_ident_row_method!(elements_in_tile_row, Element, Tile);
337    count_1d_ident_col_method!(elements_in_tile_col, Element, Tile);
338    count_2d_method!(elements_in_tile_mk, Element, Tile, M, K);
339    count_2d_method!(elements_in_tile_nk, Element, Tile, N, K);
340    count_2d_method!(elements_in_tile_mn, Element, Tile, M, N);
341    count_2d_ident_method!(elements_in_tile, Element, Tile);
342
343    count_1d_method!(elements_in_global_partition_m, Element, GlobalPartition, M);
344    count_1d_method!(elements_in_global_partition_n, Element, GlobalPartition, N);
345    count_1d_method!(tiles_in_global_partition_m, Tile, GlobalPartition, M);
346    count_1d_method!(tiles_in_global_partition_n, Tile, GlobalPartition, N);
347    count_1d_method!(
348        stage_partitions_in_global_partition_m,
349        StagePartition,
350        GlobalPartition,
351        M
352    );
353    count_1d_method!(
354        stage_partitions_in_global_partition_n,
355        StagePartition,
356        GlobalPartition,
357        N
358    );
359    count_1d_method!(stages_in_global_partition_m, Stage, GlobalPartition, M);
360    count_1d_method!(stages_in_global_partition_n, Stage, GlobalPartition, N);
361}