1use cubecl_core::{Runtime, client::ComputeClient};
2
3use crate::components::{
4 MatmulKind, MatmulProblem, MatmulSelection, MatrixLayout, TilingScheme,
5 batch::{CubeCountPlanSelection, GlobalOrderSelection, HypercubeSelection, SmAllocation},
6 stage::PartitionBuffering,
7};
8
/// Chooses which family of tile shapes the unit selectors in this file pick from.
///
/// The general and mat-vec selectors read this through
/// [`UnitMatmulSelectionOptions::tile`]; the other selectors ignore it.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub enum TileSizeSelection {
    /// Prefer the smaller tile shape where the selector offers one for the
    /// given matrix layouts.
    MinTileSize,
    /// Prefer the larger tile shape. This is the default.
    #[default]
    MaxTileSize,
}
17
/// Controls whether partition sizes adapt to the size of the problem axis.
///
/// Consumed by `scale_partition` in this file.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub enum PartitionScaling {
    /// The partition exponent grows with the axis size, capped at the
    /// selector's maximum exponent. This is the default.
    #[default]
    Enabled,
    /// The partition always uses the selector's maximum exponent, regardless
    /// of the axis size.
    Disabled,
}
24
/// Controls whether the stage size computed by a selector is scaled down.
#[derive(Default, Clone, Copy, Debug, PartialEq, Eq)]
pub enum StageScaling {
    /// Divide both the m and n stage sizes by the given factor.
    Enabled(u8),
    /// Use the stage sizes unchanged. This is the default.
    #[default]
    Disabled,
}
31
32#[derive(Default, Clone, Copy, Debug)]
33pub struct UnitMatmulSelectionOptions {
34 pub tile: TileSizeSelection,
35 pub stage: StageScaling,
36 pub partition: PartitionScaling,
37}
38
39pub fn unit_matmul_selection<R: Runtime>(
41 client: &ComputeClient<R::Server, R::Channel>,
42 problem: &MatmulProblem,
43 plane_dim: u32,
44 double_buffering: bool,
45 options: UnitMatmulSelectionOptions,
46) -> MatmulSelection {
47 let kind: MatmulKind = problem.into();
48 let num_sms = client.properties().hardware.num_streaming_multiprocessors;
49
50 match kind {
51 MatmulKind::General => {
52 general_unit_selector(problem, plane_dim, double_buffering, num_sms, options)
53 }
54 MatmulKind::MatVec => {
55 matvec_unit_selector(problem, plane_dim, double_buffering, num_sms, options)
56 }
57 MatmulKind::VecMat => vecmat_unit_selector(problem, plane_dim, double_buffering, num_sms),
58 MatmulKind::ScalarVec => {
59 scalarvec_unit_selector(problem, plane_dim, double_buffering, num_sms)
60 }
61 MatmulKind::VecScalar => {
62 vecscalar_unit_selector(problem, plane_dim, double_buffering, num_sms)
63 }
64 MatmulKind::InnerProduct => {
65 inner_product_unit_selector(problem, plane_dim, double_buffering, num_sms)
66 }
67 MatmulKind::OuterProduct => {
68 outer_product_unit_selector(problem, plane_dim, double_buffering, num_sms)
69 }
70 MatmulKind::ScalarProduct => {
71 scalar_product_unit_selector(problem, plane_dim, double_buffering, num_sms)
72 }
73 }
74}
75
76fn general_unit_selector(
78 problem: &MatmulProblem,
79 plane_dim: u32,
80 double_buffering: bool,
81 num_sms: Option<u32>,
82 options: UnitMatmulSelectionOptions,
83) -> MatmulSelection {
84 use MatrixLayout::*;
85
86 let (tile_size, mut partition_size) =
88 match (problem.lhs_layout, problem.rhs_layout, options.tile) {
89 (RowMajor, _, TileSizeSelection::MinTileSize) => (
90 (1, 4, 4),
91 (
92 scale_partition(options.partition, problem.m, 4, 9),
93 2,
94 scale_partition(options.partition, problem.k, 2, 10),
95 ),
96 ),
97 (ColMajor, RowMajor, TileSizeSelection::MinTileSize) => (
98 (4, 4, 1),
99 (2, 2, scale_partition(options.partition, problem.k, 3, 10)),
100 ),
101 (ColMajor, ColMajor, _) | (_, _, TileSizeSelection::MaxTileSize) => (
102 (4, 4, 4),
103 (
104 scale_partition(options.partition, problem.m, 2, 9),
105 2,
106 scale_partition(options.partition, problem.k, 2, 9),
107 ),
108 ),
109 };
110
111 if double_buffering && partition_size.2 > 2 {
113 partition_size.2 /= 2;
114 }
115
116 selection(
117 tile_size,
118 partition_size,
119 PartitionBuffering::Single,
120 plane_dim,
121 StageSelection::WithPlane {
122 plane_dim,
123 num_plane: 8,
124 },
125 num_sms,
126 GlobalOrderSelection::SwizzleRow {
127 m: problem.m as u32,
128 w: 4,
129 },
130 options.stage,
131 )
132}
133
134fn matvec_unit_selector(
136 problem: &MatmulProblem,
137 plane_dim: u32,
138 _double_buffering: bool,
139 num_sms: Option<u32>,
140 options: UnitMatmulSelectionOptions,
141) -> MatmulSelection {
142 use MatrixLayout::*;
143 let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout, options.tile) {
144 (RowMajor, _, TileSizeSelection::MinTileSize) => ((1, 1, 4), (1, 1, 4)),
145 _ => ((4, 1, 4), (1, 1, 4)),
146 };
147
148 selection(
149 tile_size,
150 partition_size,
151 PartitionBuffering::Single,
152 plane_dim,
153 StageSelection::Fixed { m: 8, n: 8 },
154 num_sms,
155 GlobalOrderSelection::Default,
156 options.stage,
157 )
158}
159
160fn vecmat_unit_selector(
162 problem: &MatmulProblem,
163 plane_dim: u32,
164 _double_buffering: bool,
165 num_sms: Option<u32>,
166) -> MatmulSelection {
167 use MatrixLayout::*;
168 let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
169 (RowMajor, RowMajor) => ((1, 4, 4), (1, 1, 4)),
170 (RowMajor, ColMajor) => (
171 (1, 4, 4),
172 (2, 1, scale_partition(Default::default(), problem.k, 3, 7)),
173 ),
174 (ColMajor, RowMajor) => ((1, 4, 4), (1, 1, 4)),
175 (ColMajor, ColMajor) => (
176 (1, 4, 4),
177 (
178 2,
179 1,
180 scale_partition(PartitionScaling::Enabled, problem.k, 3, 7),
181 ),
182 ),
183 };
184
185 selection(
186 tile_size,
187 partition_size,
188 PartitionBuffering::Single,
189 plane_dim,
190 StageSelection::Fixed { m: 8, n: 8 },
191 num_sms,
192 GlobalOrderSelection::Default,
193 StageScaling::Disabled,
194 )
195}
196
197fn scalarvec_unit_selector(
199 problem: &MatmulProblem,
200 plane_dim: u32,
201 _double_buffering: bool,
202 num_sms: Option<u32>,
203) -> MatmulSelection {
204 use MatrixLayout::*;
205 let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
206 (RowMajor, RowMajor) => ((1, 4, 4), (1, 2, 1)),
207 (RowMajor, ColMajor) => ((1, 4, 4), (1, 2, 1)),
208 (ColMajor, RowMajor) => ((1, 4, 4), (1, 2, 1)),
209 (ColMajor, ColMajor) => ((1, 4, 4), (2, 2, 1)),
210 };
211
212 selection(
213 tile_size,
214 partition_size,
215 PartitionBuffering::Single,
216 plane_dim,
217 StageSelection::Fixed { m: 4, n: 8 },
218 num_sms,
219 GlobalOrderSelection::Default,
220 StageScaling::Disabled,
221 )
222}
223
224fn vecscalar_unit_selector(
226 _problem: &MatmulProblem,
227 plane_dim: u32,
228 _double_buffering: bool,
229 num_sms: Option<u32>,
230) -> MatmulSelection {
231 let (tile_size, partition_size) = ((4, 1, 4), (1, 1, 2));
232
233 selection(
234 tile_size,
235 partition_size,
236 PartitionBuffering::Single,
237 plane_dim,
238 StageSelection::Fixed { m: 8, n: 4 },
239 num_sms,
240 GlobalOrderSelection::Default,
241 StageScaling::Disabled,
242 )
243}
244
245fn inner_product_unit_selector(
247 problem: &MatmulProblem,
248 plane_dim: u32,
249 _double_buffering: bool,
250 num_sms: Option<u32>,
251) -> MatmulSelection {
252 use MatrixLayout::*;
253 let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
254 (RowMajor, RowMajor) => ((1, 1, 4), (1, 1, 1)),
255 (RowMajor, ColMajor) => ((1, 1, 4), (1, 1, 1)),
256 (ColMajor, RowMajor) => ((1, 1, 4), (1, 1, 1)),
257 (ColMajor, ColMajor) => ((1, 1, 4), (1, 1, 1)),
258 };
259
260 selection(
261 tile_size,
262 partition_size,
263 PartitionBuffering::Single,
264 plane_dim,
265 StageSelection::Fixed { m: 4, n: 8 },
266 num_sms,
267 GlobalOrderSelection::Default,
268 StageScaling::Disabled,
269 )
270}
271
272fn outer_product_unit_selector(
274 _problem: &MatmulProblem,
275 plane_dim: u32,
276 _double_buffering: bool,
277 num_sms: Option<u32>,
278) -> MatmulSelection {
279 let (tile_size, partition_size) = ((4, 4, 1), (1, 1, 1));
280
281 selection(
282 tile_size,
283 partition_size,
284 PartitionBuffering::Single,
285 plane_dim,
286 StageSelection::Fixed { m: 8, n: 8 },
287 num_sms,
288 GlobalOrderSelection::Default,
289 StageScaling::Disabled,
290 )
291}
292
293fn scalar_product_unit_selector(
295 _problem: &MatmulProblem,
296 plane_dim: u32,
297 _double_buffering: bool,
298 num_sms: Option<u32>,
299) -> MatmulSelection {
300 let (tile_size, partition_size) = ((1, 1, 1), (1, 1, 1));
301
302 selection(
303 tile_size,
304 partition_size,
305 PartitionBuffering::Single,
306 plane_dim,
307 StageSelection::WithPlane {
308 plane_dim,
309 num_plane: 8,
310 },
311 num_sms,
312 GlobalOrderSelection::Default,
313 StageScaling::Disabled,
314 )
315}
316
/// How the (m, n) stage size of a selection is determined.
enum StageSelection {
    /// Derive the stage size from a total unit count of
    /// `num_plane * plane_dim`.
    WithPlane {
        /// Number of units in one plane.
        plane_dim: u32,
        /// Number of planes.
        num_plane: u32,
    },
    /// Use the given stage size as-is.
    Fixed {
        /// Stage size along m.
        m: u32,
        /// Stage size along n.
        n: u32,
    },
}
321
322impl StageSelection {
323 fn into_stages(self) -> (u32, u32) {
324 match self {
325 StageSelection::WithPlane {
326 plane_dim: plane_size,
327 num_plane: num_planes,
328 } => {
329 let num_units = num_planes * plane_size;
330 closest_factor_pair(num_units)
331 }
332 StageSelection::Fixed { m, n } => (m, n),
333 }
334 }
335}
336
337#[allow(clippy::too_many_arguments)]
338fn selection(
339 t: (u32, u32, u32),
340 p: (u32, u32, u32),
341 buffering: PartitionBuffering,
342 plane_dim: u32,
343 stage: StageSelection,
344 num_sms: Option<u32>,
345 global_order_config: GlobalOrderSelection,
346 stage_scaling: StageScaling,
347) -> MatmulSelection {
348 let (stage_size_m, stage_size_n) = stage.into_stages();
349
350 let (stage_size_m, stage_size_n) = match stage_scaling {
351 StageScaling::Enabled(f) => (stage_size_m / f as u32, stage_size_n / f as u32),
352 StageScaling::Disabled => (stage_size_m, stage_size_n),
353 };
354
355 let tiling_scheme = TilingScheme::builder()
356 .with_tile_size(t.into())
357 .with_partition_size(p.into())
358 .with_stage_size((stage_size_m, stage_size_n, 1).into())
359 .build()
360 .unwrap();
361
362 let cube_count_plan = match num_sms {
363 Some(num_sms) => CubeCountPlanSelection::Sm {
364 num_sms,
365 sm_usage: SmAllocation::Exact,
366 cubes_first: false,
367 },
368 None => CubeCountPlanSelection::Flattened,
369 };
370
371 let hypercube = HypercubeSelection::builder(&tiling_scheme)
372 .global_order(global_order_config)
373 .cube_count_plan(cube_count_plan)
374 .build();
375
376 MatmulSelection::builder(tiling_scheme, plane_dim)
377 .partition_buffering(buffering)
378 .hypercube_config(hypercube)
379 .build()
380}
381
/// Splits `n` into the pair of factors `(a, b)` with `a * b == n` and
/// `a >= b` whose members are as close to each other as possible.
///
/// The search starts at the integer square root of `n` and walks down to the
/// first divisor. If there is nothing to search (as for `n == 0`), `(n, 1)`
/// is returned.
pub fn closest_factor_pair(n: u32) -> (u32, u32) {
    let root = (n as f64).sqrt() as u32;

    (1..=root)
        .rev()
        .find(|divisor| n % divisor == 0)
        .map_or((n, 1), |divisor| (n / divisor, divisor))
}
393
394fn scale_partition(setting: PartitionScaling, axis: usize, max_exp: u32, div_exp: u32) -> u32 {
395 if let PartitionScaling::Disabled = setting {
396 return 2u32.pow(max_exp);
397 }
398
399 let exp = u32::min((axis as u32 / 2u32.pow(div_exp)) + 1, max_exp);
400 2u32.pow(exp)
401}