1use crate::components::{
2 MatmulKind, MatmulLineSizes, MatmulProblem, MatmulSelection, MatrixLayout, TilingScheme,
3 batch::{CubeCountPlanSelection, GlobalOrderSelection, HypercubeSelection, SmAllocation},
4 stage::PartitionBuffering,
5};
6use cubecl_core::{Runtime, client::ComputeClient};
7
/// Controls how large the per-unit tile is for general (matrix × matrix) problems.
///
/// Only the general heuristic reads this option; every other problem kind ignores it.
#[derive(Clone, Copy, Debug, Default)]
pub enum TileSizeSelection {
    /// For some layout combinations, collapse one tile axis to 1 (e.g. `(1, t, t)` for a
    /// row-major lhs). A column-major lhs and rhs still gets the cubic `(t, t, t)` tile.
    MinTileSize,
    /// Always use the cubic `(t, t, t)` tile.
    #[default]
    MaxTileSize,
}
16
/// Whether partition sizes grow with the problem size.
///
/// When enabled, a partition dimension is a power of two whose exponent increases with the
/// corresponding problem axis, up to a cap. When disabled, the capped value is always used.
#[derive(Clone, Copy, Debug, Default)]
pub enum PartitionScaling {
    /// Scale partition sizes with the problem axis (default).
    #[default]
    Enabled,
    /// Always use the maximum partition size.
    Disabled,
}
23
/// Optional shrinking of the stage size.
///
/// With `Enabled(f)`, both stage dimensions are divided by `f` when the selection is built.
#[derive(Clone, Copy, Debug, Default)]
pub enum StageScaling {
    /// Divide the stage `m` and `n` sizes by the given factor.
    Enabled(u8),
    /// Keep the stage size produced by the heuristic (default).
    #[default]
    Disabled,
}
30
31#[derive(Default, Clone, Copy, Debug)]
32pub struct UnitMatmulSelectionOptions {
33 pub tile: TileSizeSelection,
34 pub stage: StageScaling,
35 pub partition: PartitionScaling,
36}
37
38pub fn unit_matmul_selection<R: Runtime>(
40 client: &ComputeClient<R::Server>,
41 problem: &MatmulProblem,
42 plane_dim: u32,
43 double_buffering: bool,
44 line_size: &MatmulLineSizes,
45 options: UnitMatmulSelectionOptions,
46) -> MatmulSelection {
47 let kind: MatmulKind = problem.into();
48 let num_sms = client.properties().hardware.num_streaming_multiprocessors;
49 let min_tile_size = u8::max(line_size.lhs, line_size.rhs);
50 let min_tile_size = u8::max(line_size.out, min_tile_size) as u32;
51 let tile_size = u32::max(min_tile_size, 4);
52
53 match kind {
54 MatmulKind::General => general_unit_selector(
55 problem,
56 plane_dim,
57 double_buffering,
58 tile_size,
59 num_sms,
60 options,
61 ),
62 MatmulKind::MatVec => matvec_unit_selector(
63 problem,
64 plane_dim,
65 double_buffering,
66 tile_size,
67 num_sms,
68 options,
69 ),
70 MatmulKind::VecMat => vecmat_unit_selector(
71 problem,
72 plane_dim,
73 double_buffering,
74 tile_size,
75 num_sms,
76 options,
77 ),
78 MatmulKind::ScalarVec => {
79 scalarvec_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms)
80 }
81 MatmulKind::VecScalar => {
82 vecscalar_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms)
83 }
84 MatmulKind::InnerProduct => {
85 inner_product_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms)
86 }
87 MatmulKind::OuterProduct => {
88 outer_product_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms)
89 }
90 MatmulKind::ScalarProduct => {
91 scalar_product_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms)
92 }
93 }
94}
95
96fn general_unit_selector(
98 problem: &MatmulProblem,
99 plane_dim: u32,
100 double_buffering: bool,
101 tile_size: u32,
102 num_sms: Option<u32>,
103 options: UnitMatmulSelectionOptions,
104) -> MatmulSelection {
105 use MatrixLayout::*;
106
107 let (tile_size, mut partition_size) =
109 match (problem.lhs_layout, problem.rhs_layout, options.tile) {
110 (RowMajor, _, TileSizeSelection::MinTileSize) => (
111 (1, tile_size, tile_size),
112 (
113 scale_partition(options.partition, problem.m, 4, 9),
114 2,
115 scale_partition(options.partition, problem.k, 2, 10),
116 ),
117 ),
118 (ColMajor, RowMajor, TileSizeSelection::MinTileSize) => (
119 (tile_size, tile_size, 1),
120 (2, 2, scale_partition(options.partition, problem.k, 3, 10)),
121 ),
122 (ColMajor, ColMajor, _) | (_, _, TileSizeSelection::MaxTileSize) => (
123 (tile_size, tile_size, tile_size),
124 (
125 scale_partition(options.partition, problem.m, 2, 9),
126 2,
127 scale_partition(options.partition, problem.k, 2, 9),
128 ),
129 ),
130 };
131
132 let mut num_plane = 8;
133
134 if double_buffering {
135 if partition_size.0 > 2 {
136 partition_size.0 /= 2;
137 }
138 if partition_size.2 > 2 {
139 partition_size.2 /= 2;
140 }
141 num_plane /= 2;
142 }
143
144 selection(
145 tile_size,
146 partition_size,
147 PartitionBuffering::Single,
148 plane_dim,
149 StageSelection::WithPlane {
150 plane_dim,
151 num_plane,
152 },
153 num_sms,
154 GlobalOrderSelection::SwizzleRow {
155 m: problem.m as u32,
156 w: 4,
157 },
158 options.stage,
159 )
160}
161
162fn matvec_unit_selector(
164 problem: &MatmulProblem,
165 plane_dim: u32,
166 _double_buffering: bool,
167 tile_size: u32,
168 num_sms: Option<u32>,
169 _options: UnitMatmulSelectionOptions,
170) -> MatmulSelection {
171 let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
172 (MatrixLayout::RowMajor, _) => ((1, 1, tile_size), (1, 1, tile_size * 2)),
173 _ => ((tile_size, 1, tile_size), (1, 1, 1)),
174 };
175
176 selection(
177 tile_size,
178 partition_size,
179 PartitionBuffering::Single,
180 plane_dim,
181 StageSelection::Fixed {
182 m: plane_dim / 2,
183 n: 2,
184 },
185 num_sms,
186 GlobalOrderSelection::Default,
187 StageScaling::Disabled,
188 )
189}
190
191fn vecmat_unit_selector(
193 _problem: &MatmulProblem,
194 plane_dim: u32,
195 _double_buffering: bool,
196 tile_size: u32,
197 num_sms: Option<u32>,
198 _options: UnitMatmulSelectionOptions,
199) -> MatmulSelection {
200 let (tile_size, partition_size) = ((1, tile_size, tile_size), (1, 1, 1));
201
202 selection(
203 tile_size,
204 partition_size,
205 PartitionBuffering::Single,
206 plane_dim,
207 StageSelection::Fixed {
208 m: 2,
209 n: plane_dim / 2,
210 },
211 num_sms,
212 GlobalOrderSelection::Default,
213 StageScaling::Disabled,
214 )
215}
216
217fn scalarvec_unit_selector(
219 problem: &MatmulProblem,
220 plane_dim: u32,
221 _double_buffering: bool,
222 tile_size: u32,
223 num_sms: Option<u32>,
224) -> MatmulSelection {
225 use MatrixLayout::*;
226 let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
227 (RowMajor, RowMajor) => ((1, tile_size, tile_size), (1, 2, 1)),
228 (RowMajor, ColMajor) => ((1, tile_size, tile_size), (1, 2, 1)),
229 (ColMajor, RowMajor) => ((1, tile_size, tile_size), (1, 2, 1)),
230 (ColMajor, ColMajor) => ((1, tile_size, tile_size), (2, 2, 1)),
231 };
232
233 selection(
234 tile_size,
235 partition_size,
236 PartitionBuffering::Single,
237 plane_dim,
238 StageSelection::Fixed {
239 m: 2,
240 n: plane_dim / 2,
241 },
242 num_sms,
243 GlobalOrderSelection::Default,
244 StageScaling::Disabled,
245 )
246}
247
248fn vecscalar_unit_selector(
250 _problem: &MatmulProblem,
251 plane_dim: u32,
252 _double_buffering: bool,
253 tile_size: u32,
254 num_sms: Option<u32>,
255) -> MatmulSelection {
256 let (tile_size, partition_size) = ((tile_size, 1, 1), (1, 1, 1));
257
258 selection(
259 tile_size,
260 partition_size,
261 PartitionBuffering::Single,
262 plane_dim,
263 StageSelection::Fixed {
264 m: plane_dim / 2,
265 n: 2,
266 },
267 num_sms,
268 GlobalOrderSelection::Default,
269 StageScaling::Disabled,
270 )
271}
272
273fn inner_product_unit_selector(
275 problem: &MatmulProblem,
276 plane_dim: u32,
277 _double_buffering: bool,
278 tile_size: u32,
279 num_sms: Option<u32>,
280) -> MatmulSelection {
281 use MatrixLayout::*;
282 let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
283 (RowMajor, RowMajor) => ((1, 1, tile_size), (1, 1, 1)),
284 (RowMajor, ColMajor) => ((1, 1, tile_size), (1, 1, 1)),
285 (ColMajor, RowMajor) => ((1, 1, tile_size), (1, 1, 1)),
286 (ColMajor, ColMajor) => ((1, 1, tile_size), (1, 1, 1)),
287 };
288
289 selection(
290 tile_size,
291 partition_size,
292 PartitionBuffering::Single,
293 plane_dim,
294 StageSelection::Fixed { m: plane_dim, n: 1 }, num_sms,
296 GlobalOrderSelection::Default,
297 StageScaling::Disabled,
298 )
299}
300
301fn outer_product_unit_selector(
303 _problem: &MatmulProblem,
304 plane_dim: u32,
305 _double_buffering: bool,
306 tile_size: u32,
307 num_sms: Option<u32>,
308) -> MatmulSelection {
309 let (tile_size, partition_size) = ((tile_size, tile_size, 1), (1, 1, 1));
310
311 selection(
312 tile_size,
313 partition_size,
314 PartitionBuffering::Single,
315 plane_dim,
316 StageSelection::Fixed { m: 8, n: 8 },
317 num_sms,
318 GlobalOrderSelection::Default,
319 StageScaling::Disabled,
320 )
321}
322
323fn scalar_product_unit_selector(
325 _problem: &MatmulProblem,
326 plane_dim: u32,
327 _double_buffering: bool,
328 _tile_size: u32,
329 num_sms: Option<u32>,
330) -> MatmulSelection {
331 let (tile_size, partition_size) = ((1, 1, 1), (1, 1, 1));
332
333 selection(
334 tile_size,
335 partition_size,
336 PartitionBuffering::Single,
337 plane_dim,
338 StageSelection::WithPlane {
339 plane_dim,
340 num_plane: 1,
341 },
342 num_sms,
343 GlobalOrderSelection::Default,
344 StageScaling::Disabled,
345 )
346}
347
/// Describes how the `(m, n)` stage size is obtained before the tiling scheme is built.
enum StageSelection {
    /// Derive the stage from a unit count of `plane_dim * num_plane`, split into the most
    /// square factor pair (larger factor first).
    WithPlane {
        plane_dim: u32,
        num_plane: u32,
    },
    /// Use the given `(m, n)` stage size verbatim.
    Fixed {
        m: u32,
        n: u32,
    },
}
352
353impl StageSelection {
354 fn into_stages(self) -> (u32, u32) {
355 match self {
356 StageSelection::WithPlane {
357 plane_dim: plane_size,
358 num_plane: num_planes,
359 } => {
360 let num_units = num_planes * plane_size;
361 closest_factor_pair(num_units)
362 }
363 StageSelection::Fixed { m, n } => (m, n),
364 }
365 }
366}
367
368#[allow(clippy::too_many_arguments)]
369fn selection(
370 t: (u32, u32, u32),
371 p: (u32, u32, u32),
372 buffering: PartitionBuffering,
373 plane_dim: u32,
374 stage: StageSelection,
375 num_sms: Option<u32>,
376 global_order_config: GlobalOrderSelection,
377 stage_scaling: StageScaling,
378) -> MatmulSelection {
379 let (stage_size_m, stage_size_n) = stage.into_stages();
380
381 let (stage_size_m, stage_size_n) = match stage_scaling {
382 StageScaling::Enabled(f) => (stage_size_m / f as u32, stage_size_n / f as u32),
383 StageScaling::Disabled => (stage_size_m, stage_size_n),
384 };
385
386 let tiling_scheme = TilingScheme::builder()
387 .with_tile_size(t.into())
388 .with_partition_size(p.into())
389 .with_stage_size((stage_size_m, stage_size_n, 1).into())
390 .build()
391 .unwrap();
392
393 let cube_count_plan = match num_sms {
394 Some(num_sms) => CubeCountPlanSelection::Sm {
395 num_sms,
396 sm_usage: SmAllocation::Exact,
397 cubes_first: false,
398 },
399 None => CubeCountPlanSelection::Flattened,
400 };
401
402 let hypercube = HypercubeSelection::builder(&tiling_scheme)
403 .global_order(global_order_config)
404 .cube_count_plan(cube_count_plan)
405 .build();
406
407 MatmulSelection::builder(tiling_scheme, plane_dim)
408 .partition_buffering(buffering)
409 .hypercube_config(hypercube)
410 .build()
411}
412
/// Splits `n` into the factor pair `(a, b)` with `a * b == n`, `a >= b`, and `b` as close
/// to `sqrt(n)` as possible, i.e. the "most square" factorization.
///
/// Returns `(n, 1)` when `n` is prime or 1, and `(0, 1)` when `n` is 0.
pub fn closest_factor_pair(n: u32) -> (u32, u32) {
    let upper = (n as f64).sqrt() as u32;
    (1..=upper)
        .rev()
        .find(|&divisor| n % divisor == 0)
        .map_or((n, 1), |divisor| (n / divisor, divisor))
}
424
425fn scale_partition(setting: PartitionScaling, axis: usize, max_exp: u32, div_exp: u32) -> u32 {
426 if let PartitionScaling::Disabled = setting {
427 return 2u32.pow(max_exp);
428 }
429
430 let exp = u32::min((axis as u32 / 2u32.pow(div_exp)) + 1, max_exp);
431 2u32.pow(exp)
432}