1use super::StageIdent;
2use super::size::{GlobalPartitionSize, MatmulDim, PartitionSize, StageSize, TileSize};
3
4#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
5pub struct TilingScheme {
8 pub tile_size: TileSize,
9 pub partition_size: PartitionSize,
10 pub stage_size: StageSize,
11 pub global_partition_size: GlobalPartitionSize,
12}
13
14impl TilingScheme {
15 pub fn builder() -> TilingSchemeBuilder {
17 TilingSchemeBuilder::default()
18 }
19}
20
21#[derive(Debug, Default)]
22pub struct TilingSchemeBuilder {
24 tile_size: Option<TileSize>,
25 partition_size: Option<PartitionSize>,
26 stage_size: Option<StageSize>,
27 global_partition_size: Option<GlobalPartitionSize>,
28}
29
30impl TilingSchemeBuilder {
31 pub fn with_tile_size(mut self, tile_size: TileSize) -> Self {
33 self.tile_size = Some(tile_size);
34 self
35 }
36
37 pub fn with_partition_size(mut self, partition_size: PartitionSize) -> Self {
39 self.partition_size = Some(partition_size);
40 self
41 }
42
43 pub fn with_stage_size(mut self, stage_size: StageSize) -> Self {
47 assert!(stage_size.k == 1, "Stage size k > 1 is not supported");
48 self.stage_size = Some(stage_size);
49 self
50 }
51
52 pub fn with_global_partition_size(
56 mut self,
57 global_partition_size: GlobalPartitionSize,
58 ) -> Self {
59 self.global_partition_size = Some(global_partition_size);
60 self
61 }
62
63 pub fn build(self) -> Result<TilingScheme, &'static str> {
65 Ok(TilingScheme {
66 tile_size: self.tile_size.ok_or("Missing tile_size")?,
67 partition_size: self.partition_size.ok_or("Missing tiles_per_partition")?,
68 stage_size: self.stage_size.ok_or("Missing partitions_per_stage")?,
69 global_partition_size: self
70 .global_partition_size
71 .unwrap_or(GlobalPartitionSize::new(1, 1, 1)),
72 })
73 }
74}
75
/// One level of the tiling hierarchy, listed from outermost to innermost.
///
/// A global partition is made of stages, a stage of stage partitions, a stage
/// partition of tiles, and a tile of elements.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum TilingLevel {
    /// Outermost level; composed of stages.
    GlobalPartition,
    /// Composed of stage partitions.
    Stage,
    /// Composed of tiles.
    StagePartition,
    /// Composed of elements.
    Tile,
    /// Innermost level: a single element.
    Element,
}
84
85impl TilingScheme {
86 fn try_count_1d(
87 &self,
88 child_level: TilingLevel,
89 parent_level: TilingLevel,
90 dim: MatmulDim,
91 ) -> Option<u32> {
92 use TilingLevel::*;
93
94 match (child_level, parent_level) {
95 (child, parent) if child == parent => Some(1),
96
97 (Stage, GlobalPartition) => match dim {
98 MatmulDim::M => Some(self.global_partition_size.m),
99 MatmulDim::N => Some(self.global_partition_size.n),
100 MatmulDim::K => None,
101 },
102
103 (StagePartition, Stage) => Some(self.stage_size.get(dim)),
104
105 (Tile, StagePartition) => Some(self.partition_size.get(dim)),
106
107 (Element, Tile) => Some(self.tile_size.get(dim)),
108
109 (StagePartition, GlobalPartition) => Some(
110 self.try_count_1d(StagePartition, Stage, dim)?
111 * self.try_count_1d(Stage, GlobalPartition, dim)?,
112 ),
113
114 (Tile, GlobalPartition) => Some(
115 self.try_count_1d(Tile, Stage, dim)?
116 * self.try_count_1d(Stage, GlobalPartition, dim)?,
117 ),
118
119 (Element, GlobalPartition) => Some(
120 self.try_count_1d(Element, Stage, dim)?
121 * self.try_count_1d(Stage, GlobalPartition, dim)?,
122 ),
123
124 (Tile, Stage) => Some(
125 self.try_count_1d(StagePartition, Stage, dim)?
126 * self.try_count_1d(Tile, StagePartition, dim)?,
127 ),
128
129 (Element, Stage) => {
130 Some(self.try_count_1d(Tile, Stage, dim)? * self.try_count_1d(Element, Tile, dim)?)
131 }
132
133 (Element, StagePartition) => Some(
134 self.try_count_1d(Tile, StagePartition, dim)?
135 * self.try_count_1d(Element, Tile, dim)?,
136 ),
137
138 _ => None,
140 }
141 }
142
143 fn try_count_2d(
144 &self,
145 child_level: TilingLevel,
146 parent_level: TilingLevel,
147 dim1: MatmulDim,
148 dim2: MatmulDim,
149 ) -> Option<u32> {
150 Some(
151 self.try_count_1d(child_level, parent_level, dim1)?
152 * self.try_count_1d(child_level, parent_level, dim2)?,
153 )
154 }
155
156 fn count_1d(&self, child_level: TilingLevel, parent_level: TilingLevel, dim: MatmulDim) -> u32 {
157 self.try_count_1d(child_level, parent_level, dim)
158 .unwrap_or_else(|| {
159 panic!("Invalid hierarchy: {parent_level:?} cannot contain {child_level:?}")
160 })
161 }
162
163 fn count_1d_ident_row(
164 &self,
165 child_level: TilingLevel,
166 parent_level: TilingLevel,
167 ident: StageIdent,
168 ) -> u32 {
169 match ident {
170 StageIdent::Lhs => self.count_1d(child_level, parent_level, MatmulDim::M),
171 StageIdent::Rhs => self.count_1d(child_level, parent_level, MatmulDim::K),
172 StageIdent::Acc => self.count_1d(child_level, parent_level, MatmulDim::M),
173 StageIdent::Out => self.count_1d(child_level, parent_level, MatmulDim::M),
174 }
175 }
176
177 fn count_1d_ident_col(
178 &self,
179 child_level: TilingLevel,
180 parent_level: TilingLevel,
181 ident: StageIdent,
182 ) -> u32 {
183 match ident {
184 StageIdent::Lhs => self.count_1d(child_level, parent_level, MatmulDim::K),
185 StageIdent::Rhs => self.count_1d(child_level, parent_level, MatmulDim::N),
186 StageIdent::Acc => self.count_1d(child_level, parent_level, MatmulDim::N),
187 StageIdent::Out => self.count_1d(child_level, parent_level, MatmulDim::N),
188 }
189 }
190
191 fn count_2d(
192 &self,
193 child_level: TilingLevel,
194 parent_level: TilingLevel,
195 dim1: MatmulDim,
196 dim2: MatmulDim,
197 ) -> u32 {
198 self.try_count_2d(child_level, parent_level, dim1, dim2)
199 .unwrap_or_else(|| {
200 panic!("Invalid hierarchy: {parent_level:?} cannot contain {child_level:?}")
201 })
202 }
203
204 fn count_2d_ident(
205 &self,
206 child_level: TilingLevel,
207 parent_level: TilingLevel,
208 ident: StageIdent,
209 ) -> u32 {
210 match ident {
211 StageIdent::Lhs => self.count_2d(child_level, parent_level, MatmulDim::M, MatmulDim::K),
212 StageIdent::Rhs => self.count_2d(child_level, parent_level, MatmulDim::K, MatmulDim::N),
213 StageIdent::Acc => self.count_2d(child_level, parent_level, MatmulDim::M, MatmulDim::N),
214 StageIdent::Out => self.count_2d(child_level, parent_level, MatmulDim::M, MatmulDim::N),
215 }
216 }
217}
218
/// Generates `pub fn $name(&self) -> u32` forwarding to `count_1d` with fixed
/// levels and dimension, and attaches a generated rustdoc line to it.
macro_rules! count_1d_method {
    ($name:ident, $child:ident, $parent:ident, $dim:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child), "` units along `", stringify!($dim),
            "` in one `", stringify!($parent), "`."
        )]
        pub fn $name(&self) -> u32 {
            self.count_1d(TilingLevel::$child, TilingLevel::$parent, MatmulDim::$dim)
        }
    };
}
226
/// Generates `pub fn $name(&self, ident) -> u32` forwarding to
/// `count_1d_ident_row` with fixed levels, and attaches a generated rustdoc line.
macro_rules! count_1d_ident_row_method {
    ($name:ident, $child:ident, $parent:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child),
            "` units along the row dimension of operand `ident` in one `",
            stringify!($parent), "`."
        )]
        pub fn $name<I: Into<StageIdent>>(&self, ident: I) -> u32 {
            self.count_1d_ident_row(TilingLevel::$child, TilingLevel::$parent, ident.into())
        }
    };
}
234
/// Generates `pub fn $name(&self, ident) -> u32` forwarding to
/// `count_1d_ident_col` with fixed levels, and attaches a generated rustdoc line.
macro_rules! count_1d_ident_col_method {
    ($name:ident, $child:ident, $parent:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child),
            "` units along the column dimension of operand `ident` in one `",
            stringify!($parent), "`."
        )]
        pub fn $name<I: Into<StageIdent>>(&self, ident: I) -> u32 {
            self.count_1d_ident_col(TilingLevel::$child, TilingLevel::$parent, ident.into())
        }
    };
}
242
/// Generates `pub fn $name(&self) -> u32` forwarding to `count_2d` with fixed
/// levels and dimension pair, and attaches a generated rustdoc line.
macro_rules! count_2d_method {
    ($name:ident, $child:ident, $parent:ident, $dim1:ident, $dim2:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child), "` units over the `", stringify!($dim1),
            "` x `", stringify!($dim2), "` plane of one `", stringify!($parent), "`."
        )]
        pub fn $name(&self) -> u32 {
            self.count_2d(
                TilingLevel::$child,
                TilingLevel::$parent,
                MatmulDim::$dim1,
                MatmulDim::$dim2,
            )
        }
    };
}
255
/// Generates `pub fn $name(&self, ident) -> u32` forwarding to `count_2d_ident`
/// with fixed levels, and attaches a generated rustdoc line.
macro_rules! count_2d_ident_method {
    ($name:ident, $child:ident, $parent:ident) => {
        #[doc = concat!(
            "Number of `", stringify!($child),
            "` units covering operand `ident` (rows x columns) in one `",
            stringify!($parent), "`."
        )]
        pub fn $name<I: Into<StageIdent>>(&self, ident: I) -> u32 {
            self.count_2d_ident(TilingLevel::$child, TilingLevel::$parent, ident.into())
        }
    };
}
263
264impl TilingScheme {
265 count_1d_method!(stage_partitions_in_stage_m, StagePartition, Stage, M);
266 count_1d_method!(stage_partitions_in_stage_n, StagePartition, Stage, N);
267 count_1d_method!(stage_partitions_in_stage_k, StagePartition, Stage, K);
268 count_1d_ident_row_method!(stage_partitions_in_stage_row, StagePartition, Stage);
269 count_1d_ident_col_method!(stage_partitions_in_stage_col, StagePartition, Stage);
270 count_2d_method!(stage_partitions_in_stage_mk, StagePartition, Stage, M, K);
271 count_2d_method!(stage_partitions_in_stage_nk, StagePartition, Stage, N, K);
272 count_2d_method!(stage_partitions_in_stage_mn, StagePartition, Stage, M, N);
273 count_2d_ident_method!(stage_partitions_in_stage, StagePartition, Stage);
274
275 count_1d_method!(tiles_in_stage_m, Tile, Stage, M);
276 count_1d_method!(tiles_in_stage_n, Tile, Stage, N);
277 count_1d_method!(tiles_in_stage_k, Tile, Stage, K);
278 count_1d_ident_row_method!(tiles_in_stage_row, Tile, Stage);
279 count_1d_ident_col_method!(tiles_in_stage_col, Tile, Stage);
280 count_2d_method!(tiles_in_stage_mk, Tile, Stage, M, K);
281 count_2d_method!(tiles_in_stage_nk, Tile, Stage, N, K);
282 count_2d_method!(tiles_in_stage_mn, Tile, Stage, M, N);
283 count_2d_ident_method!(tiles_in_stage, Tile, Stage);
284
285 count_1d_method!(elements_in_stage_m, Element, Stage, M);
286 count_1d_method!(elements_in_stage_n, Element, Stage, N);
287 count_1d_method!(elements_in_stage_k, Element, Stage, K);
288 count_1d_ident_row_method!(elements_in_stage_row, Element, Stage);
289 count_1d_ident_col_method!(elements_in_stage_col, Element, Stage);
290 count_2d_method!(elements_in_stage_mk, Element, Stage, M, K);
291 count_2d_method!(elements_in_stage_nk, Element, Stage, N, K);
292 count_2d_method!(elements_in_stage_mn, Element, Stage, M, N);
293 count_2d_ident_method!(elements_in_stage, Element, Stage);
294
295 count_1d_method!(tiles_in_stage_partition_m, Tile, StagePartition, M);
296 count_1d_method!(tiles_in_stage_partition_n, Tile, StagePartition, N);
297 count_1d_method!(tiles_in_stage_partition_k, Tile, StagePartition, K);
298 count_1d_ident_row_method!(tiles_in_stage_partition_row, Tile, StagePartition);
299 count_1d_ident_col_method!(tiles_in_stage_partition_col, Tile, StagePartition);
300 count_2d_method!(tiles_in_stage_partition_mk, Tile, StagePartition, M, K);
301 count_2d_method!(tiles_in_stage_partition_nk, Tile, StagePartition, N, K);
302 count_2d_method!(tiles_in_stage_partition_mn, Tile, StagePartition, M, N);
303 count_2d_ident_method!(tiles_in_stage_partition, Tile, StagePartition);
304
305 count_1d_method!(elements_in_stage_partition_m, Element, StagePartition, M);
306 count_1d_method!(elements_in_stage_partition_n, Element, StagePartition, N);
307 count_1d_method!(elements_in_stage_partition_k, Element, StagePartition, K);
308 count_1d_ident_row_method!(elements_in_stage_partition_row, Element, StagePartition);
309 count_1d_ident_col_method!(elements_in_stage_partition_col, Element, StagePartition);
310 count_2d_method!(
311 elements_in_stage_partition_mk,
312 Element,
313 StagePartition,
314 M,
315 K
316 );
317 count_2d_method!(
318 elements_in_stage_partition_nk,
319 Element,
320 StagePartition,
321 N,
322 K
323 );
324 count_2d_method!(
325 elements_in_stage_partition_mn,
326 Element,
327 StagePartition,
328 M,
329 N
330 );
331 count_2d_ident_method!(elements_in_stage_partition, Element, StagePartition);
332
333 count_1d_method!(elements_in_tile_m, Element, Tile, M);
334 count_1d_method!(elements_in_tile_n, Element, Tile, N);
335 count_1d_method!(elements_in_tile_k, Element, Tile, K);
336 count_1d_ident_row_method!(elements_in_tile_row, Element, Tile);
337 count_1d_ident_col_method!(elements_in_tile_col, Element, Tile);
338 count_2d_method!(elements_in_tile_mk, Element, Tile, M, K);
339 count_2d_method!(elements_in_tile_nk, Element, Tile, N, K);
340 count_2d_method!(elements_in_tile_mn, Element, Tile, M, N);
341 count_2d_ident_method!(elements_in_tile, Element, Tile);
342
343 count_1d_method!(elements_in_global_partition_m, Element, GlobalPartition, M);
344 count_1d_method!(elements_in_global_partition_n, Element, GlobalPartition, N);
345 count_1d_method!(tiles_in_global_partition_m, Tile, GlobalPartition, M);
346 count_1d_method!(tiles_in_global_partition_n, Tile, GlobalPartition, N);
347 count_1d_method!(
348 stage_partitions_in_global_partition_m,
349 StagePartition,
350 GlobalPartition,
351 M
352 );
353 count_1d_method!(
354 stage_partitions_in_global_partition_n,
355 StagePartition,
356 GlobalPartition,
357 N
358 );
359 count_1d_method!(stages_in_global_partition_m, Stage, GlobalPartition, M);
360 count_1d_method!(stages_in_global_partition_n, Stage, GlobalPartition, N);
361}