1use super::Ident;
2use super::size::{GlobalPartitionSize, MatmulDim, PartitionSize, StageSize, TileSize};
3
/// Sizes of every level of the matmul tiling hierarchy.
///
/// From coarsest to finest the levels are: global partition, stage, stage partition,
/// tile, element. Each field stores how many units of the next finer level fit in one
/// unit of its own level, per matmul dimension (see the counting helpers on this type).
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
pub struct TilingScheme {
    /// Elements per tile along each matmul dimension.
    pub tile_size: TileSize,
    /// Tiles per stage partition along each matmul dimension.
    pub partition_size: PartitionSize,
    /// Stage partitions per stage along each matmul dimension.
    pub stage_size: StageSize,
    /// Stages per global partition; the counting helpers in this file read only its
    /// `m` and `n` components.
    pub global_partition_size: GlobalPartitionSize,
}
13
14impl TilingScheme {
15 pub fn builder() -> TilingSchemeBuilder {
17 TilingSchemeBuilder::default()
18 }
19}
20
/// Incremental builder for [`TilingScheme`], obtained from [`TilingScheme::builder`].
///
/// `tile_size`, `partition_size` and `stage_size` must be set before `build` succeeds;
/// `global_partition_size` is optional and falls back to a default inside `build`.
#[derive(Debug, Default)]
pub struct TilingSchemeBuilder {
    // Each field stays `None` until its matching `with_*` setter is called.
    tile_size: Option<TileSize>,
    partition_size: Option<PartitionSize>,
    stage_size: Option<StageSize>,
    global_partition_size: Option<GlobalPartitionSize>,
}
29
30impl TilingSchemeBuilder {
31 pub fn with_tile_size(mut self, tile_size: TileSize) -> Self {
33 self.tile_size = Some(tile_size);
34 self
35 }
36
37 pub fn with_partition_size(mut self, partition_size: PartitionSize) -> Self {
39 self.partition_size = Some(partition_size);
40 self
41 }
42
43 pub fn with_stage_size(mut self, stage_size: StageSize) -> Self {
47 assert!(stage_size.k == 1, "Stage size k > 1 is not supported");
48 self.stage_size = Some(stage_size);
49 self
50 }
51
52 pub fn with_global_partition_size(
56 mut self,
57 global_partition_size: GlobalPartitionSize,
58 ) -> Self {
59 self.global_partition_size = Some(global_partition_size);
60 self
61 }
62
63 pub fn build(self) -> Result<TilingScheme, &'static str> {
65 Ok(TilingScheme {
66 tile_size: self.tile_size.ok_or("Missing tile_size")?,
67 partition_size: self.partition_size.ok_or("Missing tiles_per_partition")?,
68 stage_size: self.stage_size.ok_or("Missing partitions_per_stage")?,
69 global_partition_size: self
70 .global_partition_size
71 .unwrap_or(GlobalPartitionSize::new(1, 1, 1)),
72 })
73 }
74}
75
/// Levels of the tiling hierarchy, listed from coarsest to finest.
///
/// Each level is contained in the one listed before it: a global partition holds stages,
/// a stage holds stage partitions, a stage partition holds tiles, and a tile holds
/// elements. Used as the `child_level` / `parent_level` arguments of the counting helpers.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum TilingLevel {
    /// Holds `global_partition_size` stages (along `m` and `n` only).
    GlobalPartition,
    /// Holds `stage_size` stage partitions.
    Stage,
    /// Holds `partition_size` tiles.
    StagePartition,
    /// Holds `tile_size` elements.
    Tile,
    /// A single matrix element; the finest level.
    Element,
}
84
85impl TilingScheme {
86 fn try_count_1d(
87 &self,
88 child_level: TilingLevel,
89 parent_level: TilingLevel,
90 dim: MatmulDim,
91 ) -> Option<u32> {
92 use TilingLevel::*;
93
94 match (child_level, parent_level) {
95 (child, parent) if child == parent => Some(1),
96
97 (Stage, GlobalPartition) => match dim {
98 MatmulDim::M => Some(self.global_partition_size.m),
99 MatmulDim::N => Some(self.global_partition_size.n),
100 MatmulDim::K => None,
101 },
102
103 (StagePartition, Stage) => Some(self.stage_size.get(dim)),
104
105 (Tile, StagePartition) => Some(self.partition_size.get(dim)),
106
107 (Element, Tile) => Some(self.tile_size.get(dim)),
108
109 (StagePartition, GlobalPartition) => Some(
110 self.try_count_1d(StagePartition, Stage, dim)?
111 * self.try_count_1d(Stage, GlobalPartition, dim)?,
112 ),
113
114 (Tile, GlobalPartition) => Some(
115 self.try_count_1d(Tile, Stage, dim)?
116 * self.try_count_1d(Stage, GlobalPartition, dim)?,
117 ),
118
119 (Element, GlobalPartition) => Some(
120 self.try_count_1d(Element, Stage, dim)?
121 * self.try_count_1d(Stage, GlobalPartition, dim)?,
122 ),
123
124 (Tile, Stage) => Some(
125 self.try_count_1d(StagePartition, Stage, dim)?
126 * self.try_count_1d(Tile, StagePartition, dim)?,
127 ),
128
129 (Element, Stage) => {
130 Some(self.try_count_1d(Tile, Stage, dim)? * self.try_count_1d(Element, Tile, dim)?)
131 }
132
133 (Element, StagePartition) => Some(
134 self.try_count_1d(Tile, StagePartition, dim)?
135 * self.try_count_1d(Element, Tile, dim)?,
136 ),
137
138 _ => None,
140 }
141 }
142
143 fn try_count_2d(
144 &self,
145 child_level: TilingLevel,
146 parent_level: TilingLevel,
147 dim1: MatmulDim,
148 dim2: MatmulDim,
149 ) -> Option<u32> {
150 Some(
151 self.try_count_1d(child_level, parent_level, dim1)?
152 * self.try_count_1d(child_level, parent_level, dim2)?,
153 )
154 }
155
156 fn count_1d(&self, child_level: TilingLevel, parent_level: TilingLevel, dim: MatmulDim) -> u32 {
157 self.try_count_1d(child_level, parent_level, dim)
158 .unwrap_or_else(|| {
159 panic!("Invalid hierarchy: {parent_level:?} cannot contain {child_level:?}")
160 })
161 }
162
163 fn count_1d_ident_row<I: Into<Ident>>(
164 &self,
165 child_level: TilingLevel,
166 parent_level: TilingLevel,
167 ident: I,
168 ) -> u32 {
169 match ident.into() {
170 Ident::Lhs => self.count_1d(child_level, parent_level, MatmulDim::M),
171 Ident::Rhs => self.count_1d(child_level, parent_level, MatmulDim::K),
172 Ident::Out => self.count_1d(child_level, parent_level, MatmulDim::M),
173 }
174 }
175
176 fn count_1d_ident_col<I: Into<Ident>>(
177 &self,
178 child_level: TilingLevel,
179 parent_level: TilingLevel,
180 ident: I,
181 ) -> u32 {
182 match ident.into() {
183 Ident::Lhs => self.count_1d(child_level, parent_level, MatmulDim::K),
184 Ident::Rhs => self.count_1d(child_level, parent_level, MatmulDim::N),
185 Ident::Out => self.count_1d(child_level, parent_level, MatmulDim::N),
186 }
187 }
188
189 fn count_2d(
190 &self,
191 child_level: TilingLevel,
192 parent_level: TilingLevel,
193 dim1: MatmulDim,
194 dim2: MatmulDim,
195 ) -> u32 {
196 self.try_count_2d(child_level, parent_level, dim1, dim2)
197 .unwrap_or_else(|| {
198 panic!("Invalid hierarchy: {parent_level:?} cannot contain {child_level:?}")
199 })
200 }
201
202 fn count_2d_ident<I: Into<Ident>>(
203 &self,
204 child_level: TilingLevel,
205 parent_level: TilingLevel,
206 ident: I,
207 ) -> u32 {
208 match ident.into() {
209 Ident::Lhs => self.count_2d(child_level, parent_level, MatmulDim::M, MatmulDim::K),
210 Ident::Rhs => self.count_2d(child_level, parent_level, MatmulDim::K, MatmulDim::N),
211 Ident::Out => self.count_2d(child_level, parent_level, MatmulDim::M, MatmulDim::N),
212 }
213 }
214}
215
216macro_rules! count_1d_method {
217 ($name:ident, $child:ident, $parent:ident, $dim:ident) => {
218 pub fn $name(&self) -> u32 {
219 self.count_1d(TilingLevel::$child, TilingLevel::$parent, MatmulDim::$dim)
220 }
221 };
222}
223
// Expands to `pub fn $method(&self, ident) -> u32`: how many `$child` units fit in one
// `$parent` unit along the row dimension of the operand selected by `ident`.
macro_rules! count_1d_ident_row_method {
    ($method:ident, $child:ident, $parent:ident) => {
        pub fn $method<I>(&self, ident: I) -> u32
        where
            I: Into<Ident>,
        {
            Self::count_1d_ident_row(self, TilingLevel::$child, TilingLevel::$parent, ident)
        }
    };
}
231
// Expands to `pub fn $method(&self, ident) -> u32`: how many `$child` units fit in one
// `$parent` unit along the column dimension of the operand selected by `ident`.
macro_rules! count_1d_ident_col_method {
    ($method:ident, $child:ident, $parent:ident) => {
        pub fn $method<I>(&self, ident: I) -> u32
        where
            I: Into<Ident>,
        {
            Self::count_1d_ident_col(self, TilingLevel::$child, TilingLevel::$parent, ident)
        }
    };
}
239
// Expands to `pub fn $method(&self) -> u32`: how many `$child` units fit in one `$parent`
// unit over the two matmul dimensions `$dim1` x `$dim2` (delegates to `count_2d`).
macro_rules! count_2d_method {
    ($method:ident, $child:ident, $parent:ident, $dim1:ident, $dim2:ident) => {
        pub fn $method(&self) -> u32 {
            let child_level = TilingLevel::$child;
            let parent_level = TilingLevel::$parent;
            Self::count_2d(
                self,
                child_level,
                parent_level,
                MatmulDim::$dim1,
                MatmulDim::$dim2,
            )
        }
    };
}
252
// Expands to `pub fn $method(&self, ident) -> u32`: how many `$child` units fit in one
// `$parent` unit over both (row x column) dimensions of the operand selected by `ident`.
macro_rules! count_2d_ident_method {
    ($method:ident, $child:ident, $parent:ident) => {
        pub fn $method<I>(&self, ident: I) -> u32
        where
            I: Into<Ident>,
        {
            Self::count_2d_ident(self, TilingLevel::$child, TilingLevel::$parent, ident)
        }
    };
}
260
261impl TilingScheme {
262 count_1d_method!(stage_partitions_in_stage_m, StagePartition, Stage, M);
263 count_1d_method!(stage_partitions_in_stage_n, StagePartition, Stage, N);
264 count_1d_method!(stage_partitions_in_stage_k, StagePartition, Stage, K);
265 count_1d_ident_row_method!(stage_partitions_in_stage_row, StagePartition, Stage);
266 count_1d_ident_col_method!(stage_partitions_in_stage_col, StagePartition, Stage);
267 count_2d_method!(stage_partitions_in_stage_mk, StagePartition, Stage, M, K);
268 count_2d_method!(stage_partitions_in_stage_nk, StagePartition, Stage, N, K);
269 count_2d_method!(stage_partitions_in_stage_mn, StagePartition, Stage, M, N);
270 count_2d_ident_method!(stage_partitions_in_stage, StagePartition, Stage);
271
272 count_1d_method!(tiles_in_stage_m, Tile, Stage, M);
273 count_1d_method!(tiles_in_stage_n, Tile, Stage, N);
274 count_1d_method!(tiles_in_stage_k, Tile, Stage, K);
275 count_1d_ident_row_method!(tiles_in_stage_row, Tile, Stage);
276 count_1d_ident_col_method!(tiles_in_stage_col, Tile, Stage);
277 count_2d_method!(tiles_in_stage_mk, Tile, Stage, M, K);
278 count_2d_method!(tiles_in_stage_nk, Tile, Stage, N, K);
279 count_2d_method!(tiles_in_stage_mn, Tile, Stage, M, N);
280 count_2d_ident_method!(tiles_in_stage, Tile, Stage);
281
282 count_1d_method!(elements_in_stage_m, Element, Stage, M);
283 count_1d_method!(elements_in_stage_n, Element, Stage, N);
284 count_1d_method!(elements_in_stage_k, Element, Stage, K);
285 count_1d_ident_row_method!(elements_in_stage_row, Element, Stage);
286 count_1d_ident_col_method!(elements_in_stage_col, Element, Stage);
287 count_2d_method!(elements_in_stage_mk, Element, Stage, M, K);
288 count_2d_method!(elements_in_stage_nk, Element, Stage, N, K);
289 count_2d_method!(elements_in_stage_mn, Element, Stage, M, N);
290 count_2d_ident_method!(elements_in_stage, Element, Stage);
291
292 count_1d_method!(tiles_in_stage_partition_m, Tile, StagePartition, M);
293 count_1d_method!(tiles_in_stage_partition_n, Tile, StagePartition, N);
294 count_1d_method!(tiles_in_stage_partition_k, Tile, StagePartition, K);
295 count_1d_ident_row_method!(tiles_in_stage_partition_row, Tile, StagePartition);
296 count_1d_ident_col_method!(tiles_in_stage_partition_col, Tile, StagePartition);
297 count_2d_method!(tiles_in_stage_partition_mk, Tile, StagePartition, M, K);
298 count_2d_method!(tiles_in_stage_partition_nk, Tile, StagePartition, N, K);
299 count_2d_method!(tiles_in_stage_partition_mn, Tile, StagePartition, M, N);
300 count_2d_ident_method!(tiles_in_stage_partition, Tile, StagePartition);
301
302 count_1d_method!(elements_in_stage_partition_m, Element, StagePartition, M);
303 count_1d_method!(elements_in_stage_partition_n, Element, StagePartition, N);
304 count_1d_method!(elements_in_stage_partition_k, Element, StagePartition, K);
305 count_1d_ident_row_method!(elements_in_stage_partition_row, Element, StagePartition);
306 count_1d_ident_col_method!(elements_in_stage_partition_col, Element, StagePartition);
307 count_2d_method!(
308 elements_in_stage_partition_mk,
309 Element,
310 StagePartition,
311 M,
312 K
313 );
314 count_2d_method!(
315 elements_in_stage_partition_nk,
316 Element,
317 StagePartition,
318 N,
319 K
320 );
321 count_2d_method!(
322 elements_in_stage_partition_mn,
323 Element,
324 StagePartition,
325 M,
326 N
327 );
328 count_2d_ident_method!(elements_in_stage_partition, Element, StagePartition);
329
330 count_1d_method!(elements_in_tile_m, Element, Tile, M);
331 count_1d_method!(elements_in_tile_n, Element, Tile, N);
332 count_1d_method!(elements_in_tile_k, Element, Tile, K);
333 count_1d_ident_row_method!(elements_in_tile_row, Element, Tile);
334 count_1d_ident_col_method!(elements_in_tile_col, Element, Tile);
335 count_2d_method!(elements_in_tile_mk, Element, Tile, M, K);
336 count_2d_method!(elements_in_tile_nk, Element, Tile, N, K);
337 count_2d_method!(elements_in_tile_mn, Element, Tile, M, N);
338 count_2d_ident_method!(elements_in_tile, Element, Tile);
339
340 count_1d_method!(elements_in_global_partition_m, Element, GlobalPartition, M);
341 count_1d_method!(elements_in_global_partition_n, Element, GlobalPartition, N);
342 count_1d_method!(tiles_in_global_partition_m, Tile, GlobalPartition, M);
343 count_1d_method!(tiles_in_global_partition_n, Tile, GlobalPartition, N);
344 count_1d_method!(
345 stage_partitions_in_global_partition_m,
346 StagePartition,
347 GlobalPartition,
348 M
349 );
350 count_1d_method!(
351 stage_partitions_in_global_partition_n,
352 StagePartition,
353 GlobalPartition,
354 N
355 );
356 count_1d_method!(stages_in_global_partition_m, Stage, GlobalPartition, M);
357 count_1d_method!(stages_in_global_partition_n, Stage, GlobalPartition, N);
358}