// cubecl_matmul/kernels/layered/selector/unit.rs

1use crate::components::{
2    MatmulKind, MatmulLineSizes, MatmulProblem, MatmulSelection, MatrixLayout, TilingScheme,
3    batch::{CubeCountPlanSelection, GlobalOrderSelection, HypercubeSelection, SmAllocation},
4    stage::PartitionBuffering,
5};
6use cubecl_core::{Runtime, client::ComputeClient};
7
/// Strategy for picking the tile size in the general unit matmul selector.
#[derive(Default, Clone, Copy, Debug)]
pub enum TileSizeSelection {
    /// Chooses the smallest tile size possible.
    MinTileSize,
    /// Chooses the biggest tile size possible.
    #[default]
    MaxTileSize,
}
16
/// Whether partition counts in the general unit selector grow with the problem size.
#[derive(Default, Clone, Copy, Debug)]
pub enum PartitionScaling {
    /// Partition counts scale with the matching problem axis, up to a cap.
    #[default]
    Enabled,
    /// Partition counts always use their capped (maximum) value.
    Disabled,
}
23
/// Optional down-scaling of the stage size.
#[derive(Default, Clone, Copy, Debug)]
pub enum StageScaling {
    /// Divide both stage dimensions (m and n) by the given factor.
    Enabled(u8),
    /// Keep the stage size as selected.
    #[default]
    Disabled,
}
30
31#[derive(Default, Clone, Copy, Debug)]
32pub struct UnitMatmulSelectionOptions {
33    pub tile: TileSizeSelection,
34    pub stage: StageScaling,
35    pub partition: PartitionScaling,
36}
37
38/// Computes a [MatmulSelection] depending on the problem kind
39pub fn unit_matmul_selection<R: Runtime>(
40    client: &ComputeClient<R::Server>,
41    problem: &MatmulProblem,
42    plane_dim: u32,
43    double_buffering: bool,
44    line_size: &MatmulLineSizes,
45    options: UnitMatmulSelectionOptions,
46) -> MatmulSelection {
47    let kind: MatmulKind = problem.into();
48    let num_sms = client.properties().hardware.num_streaming_multiprocessors;
49    let min_tile_size = u8::max(line_size.lhs, line_size.rhs);
50    let min_tile_size = u8::max(line_size.out, min_tile_size) as u32;
51    let tile_size = u32::max(min_tile_size, 4);
52
53    match kind {
54        MatmulKind::General => general_unit_selector(
55            problem,
56            plane_dim,
57            double_buffering,
58            tile_size,
59            num_sms,
60            options,
61        ),
62        MatmulKind::MatVec => matvec_unit_selector(
63            problem,
64            plane_dim,
65            double_buffering,
66            tile_size,
67            num_sms,
68            options,
69        ),
70        MatmulKind::VecMat => vecmat_unit_selector(
71            problem,
72            plane_dim,
73            double_buffering,
74            tile_size,
75            num_sms,
76            options,
77        ),
78        MatmulKind::ScalarVec => {
79            scalarvec_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms)
80        }
81        MatmulKind::VecScalar => {
82            vecscalar_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms)
83        }
84        MatmulKind::InnerProduct => {
85            inner_product_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms)
86        }
87        MatmulKind::OuterProduct => {
88            outer_product_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms)
89        }
90        MatmulKind::ScalarProduct => {
91            scalar_product_unit_selector(problem, plane_dim, double_buffering, tile_size, num_sms)
92        }
93    }
94}
95
96/// (M, K) @ (K, N) → (M, N), with M, K, N > 1
97fn general_unit_selector(
98    problem: &MatmulProblem,
99    plane_dim: u32,
100    double_buffering: bool,
101    tile_size: u32,
102    num_sms: Option<u32>,
103    options: UnitMatmulSelectionOptions,
104) -> MatmulSelection {
105    use MatrixLayout::*;
106
107    // Manually tested for good performance on many shapes.
108    let (tile_size, mut partition_size) =
109        match (problem.lhs_layout, problem.rhs_layout, options.tile) {
110            (RowMajor, _, TileSizeSelection::MinTileSize) => (
111                (1, tile_size, tile_size),
112                (
113                    scale_partition(options.partition, problem.m, 4, 9),
114                    2,
115                    scale_partition(options.partition, problem.k, 2, 10),
116                ),
117            ),
118            (ColMajor, RowMajor, TileSizeSelection::MinTileSize) => (
119                (tile_size, tile_size, 1),
120                (2, 2, scale_partition(options.partition, problem.k, 3, 10)),
121            ),
122            (ColMajor, ColMajor, _) | (_, _, TileSizeSelection::MaxTileSize) => (
123                (tile_size, tile_size, tile_size),
124                (
125                    scale_partition(options.partition, problem.m, 2, 9),
126                    2,
127                    scale_partition(options.partition, problem.k, 2, 9),
128                ),
129            ),
130        };
131
132    let mut num_plane = 8;
133
134    if double_buffering {
135        if partition_size.0 > 2 {
136            partition_size.0 /= 2;
137        }
138        if partition_size.2 > 2 {
139            partition_size.2 /= 2;
140        }
141        num_plane /= 2;
142    }
143
144    selection(
145        tile_size,
146        partition_size,
147        PartitionBuffering::Single,
148        plane_dim,
149        StageSelection::WithPlane {
150            plane_dim,
151            num_plane,
152        },
153        num_sms,
154        GlobalOrderSelection::SwizzleRow {
155            m: problem.m as u32,
156            w: 4,
157        },
158        options.stage,
159    )
160}
161
162/// (M, K) @ (K, 1) → (M, 1)
163fn matvec_unit_selector(
164    problem: &MatmulProblem,
165    plane_dim: u32,
166    _double_buffering: bool,
167    tile_size: u32,
168    num_sms: Option<u32>,
169    _options: UnitMatmulSelectionOptions,
170) -> MatmulSelection {
171    let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
172        (MatrixLayout::RowMajor, _) => ((1, 1, tile_size), (1, 1, tile_size * 2)),
173        _ => ((tile_size, 1, tile_size), (1, 1, 1)),
174    };
175
176    selection(
177        tile_size,
178        partition_size,
179        PartitionBuffering::Single,
180        plane_dim,
181        StageSelection::Fixed {
182            m: plane_dim / 2,
183            n: 2,
184        },
185        num_sms,
186        GlobalOrderSelection::Default,
187        StageScaling::Disabled,
188    )
189}
190
191/// (1, K) @ (K, N) → (1, N)
192fn vecmat_unit_selector(
193    _problem: &MatmulProblem,
194    plane_dim: u32,
195    _double_buffering: bool,
196    tile_size: u32,
197    num_sms: Option<u32>,
198    _options: UnitMatmulSelectionOptions,
199) -> MatmulSelection {
200    let (tile_size, partition_size) = ((1, tile_size, tile_size), (1, 1, 1));
201
202    selection(
203        tile_size,
204        partition_size,
205        PartitionBuffering::Single,
206        plane_dim,
207        StageSelection::Fixed {
208            m: 2,
209            n: plane_dim / 2,
210        },
211        num_sms,
212        GlobalOrderSelection::Default,
213        StageScaling::Disabled,
214    )
215}
216
217/// (1, 1) @ (1, N) → (1, N)
218fn scalarvec_unit_selector(
219    problem: &MatmulProblem,
220    plane_dim: u32,
221    _double_buffering: bool,
222    tile_size: u32,
223    num_sms: Option<u32>,
224) -> MatmulSelection {
225    use MatrixLayout::*;
226    let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
227        (RowMajor, RowMajor) => ((1, tile_size, tile_size), (1, 2, 1)),
228        (RowMajor, ColMajor) => ((1, tile_size, tile_size), (1, 2, 1)),
229        (ColMajor, RowMajor) => ((1, tile_size, tile_size), (1, 2, 1)),
230        (ColMajor, ColMajor) => ((1, tile_size, tile_size), (2, 2, 1)),
231    };
232
233    selection(
234        tile_size,
235        partition_size,
236        PartitionBuffering::Single,
237        plane_dim,
238        StageSelection::Fixed {
239            m: 2,
240            n: plane_dim / 2,
241        },
242        num_sms,
243        GlobalOrderSelection::Default,
244        StageScaling::Disabled,
245    )
246}
247
248/// (M, 1) @ (1, 1) → (M, 1)
249fn vecscalar_unit_selector(
250    _problem: &MatmulProblem,
251    plane_dim: u32,
252    _double_buffering: bool,
253    tile_size: u32,
254    num_sms: Option<u32>,
255) -> MatmulSelection {
256    let (tile_size, partition_size) = ((tile_size, 1, 1), (1, 1, 1));
257
258    selection(
259        tile_size,
260        partition_size,
261        PartitionBuffering::Single,
262        plane_dim,
263        StageSelection::Fixed {
264            m: plane_dim / 2,
265            n: 2,
266        },
267        num_sms,
268        GlobalOrderSelection::Default,
269        StageScaling::Disabled,
270    )
271}
272
273/// (1, K) @ (K, 1) → (1, 1)
274fn inner_product_unit_selector(
275    problem: &MatmulProblem,
276    plane_dim: u32,
277    _double_buffering: bool,
278    tile_size: u32,
279    num_sms: Option<u32>,
280) -> MatmulSelection {
281    use MatrixLayout::*;
282    let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
283        (RowMajor, RowMajor) => ((1, 1, tile_size), (1, 1, 1)),
284        (RowMajor, ColMajor) => ((1, 1, tile_size), (1, 1, 1)),
285        (ColMajor, RowMajor) => ((1, 1, tile_size), (1, 1, 1)),
286        (ColMajor, ColMajor) => ((1, 1, tile_size), (1, 1, 1)),
287    };
288
289    selection(
290        tile_size,
291        partition_size,
292        PartitionBuffering::Single,
293        plane_dim,
294        StageSelection::Fixed { m: plane_dim, n: 1 }, // TODO: most planes does nothing.
295        num_sms,
296        GlobalOrderSelection::Default,
297        StageScaling::Disabled,
298    )
299}
300
301/// (M, 1) @ (1, N) → (M, N)
302fn outer_product_unit_selector(
303    _problem: &MatmulProblem,
304    plane_dim: u32,
305    _double_buffering: bool,
306    tile_size: u32,
307    num_sms: Option<u32>,
308) -> MatmulSelection {
309    let (tile_size, partition_size) = ((tile_size, tile_size, 1), (1, 1, 1));
310
311    selection(
312        tile_size,
313        partition_size,
314        PartitionBuffering::Single,
315        plane_dim,
316        StageSelection::Fixed { m: 8, n: 8 },
317        num_sms,
318        GlobalOrderSelection::Default,
319        StageScaling::Disabled,
320    )
321}
322
323/// (1, 1) @ (1, 1) → (1, 1)
324fn scalar_product_unit_selector(
325    _problem: &MatmulProblem,
326    plane_dim: u32,
327    _double_buffering: bool,
328    _tile_size: u32,
329    num_sms: Option<u32>,
330) -> MatmulSelection {
331    let (tile_size, partition_size) = ((1, 1, 1), (1, 1, 1));
332
333    selection(
334        tile_size,
335        partition_size,
336        PartitionBuffering::Single,
337        plane_dim,
338        StageSelection::WithPlane {
339            plane_dim,
340            num_plane: 1,
341        },
342        num_sms,
343        GlobalOrderSelection::Default,
344        StageScaling::Disabled,
345    )
346}
347
/// How the stage's `(m, n)` size, in units, is chosen.
enum StageSelection {
    /// Derive the stage from `num_plane` planes of `plane_dim` units each.
    WithPlane { plane_dim: u32, num_plane: u32 },
    /// Use exactly `m` by `n`.
    Fixed { m: u32, n: u32 },
}
352
353impl StageSelection {
354    fn into_stages(self) -> (u32, u32) {
355        match self {
356            StageSelection::WithPlane {
357                plane_dim: plane_size,
358                num_plane: num_planes,
359            } => {
360                let num_units = num_planes * plane_size;
361                closest_factor_pair(num_units)
362            }
363            StageSelection::Fixed { m, n } => (m, n),
364        }
365    }
366}
367
368#[allow(clippy::too_many_arguments)]
369fn selection(
370    t: (u32, u32, u32),
371    p: (u32, u32, u32),
372    buffering: PartitionBuffering,
373    plane_dim: u32,
374    stage: StageSelection,
375    num_sms: Option<u32>,
376    global_order_config: GlobalOrderSelection,
377    stage_scaling: StageScaling,
378) -> MatmulSelection {
379    let (stage_size_m, stage_size_n) = stage.into_stages();
380
381    let (stage_size_m, stage_size_n) = match stage_scaling {
382        StageScaling::Enabled(f) => (stage_size_m / f as u32, stage_size_n / f as u32),
383        StageScaling::Disabled => (stage_size_m, stage_size_n),
384    };
385
386    let tiling_scheme = TilingScheme::builder()
387        .with_tile_size(t.into())
388        .with_partition_size(p.into())
389        .with_stage_size((stage_size_m, stage_size_n, 1).into())
390        .build()
391        .unwrap();
392
393    let cube_count_plan = match num_sms {
394        Some(num_sms) => CubeCountPlanSelection::Sm {
395            num_sms,
396            sm_usage: SmAllocation::Exact,
397            cubes_first: false,
398        },
399        None => CubeCountPlanSelection::Flattened,
400    };
401
402    let hypercube = HypercubeSelection::builder(&tiling_scheme)
403        .global_order(global_order_config)
404        .cube_count_plan(cube_count_plan)
405        .build();
406
407    MatmulSelection::builder(tiling_scheme, plane_dim)
408        .partition_buffering(buffering)
409        .hypercube_config(hypercube)
410        .build()
411}
412
/// Splits `n` into the factor pair `(a, b)` with `a * b == n`, `a >= b`, and the smallest
/// possible difference `a - b`.
///
/// Returns `(0, 1)` for `n == 0`.
pub fn closest_factor_pair(n: u32) -> (u32, u32) {
    let upper = f64::from(n).sqrt() as u32;
    // Scanning down from floor(sqrt(n)), the first divisor found is the one closest to sqrt(n).
    (1..=upper)
        .rev()
        .find(|&b| n.is_multiple_of(b))
        .map_or((n, 1), |b| (n / b, b))
}
424
425fn scale_partition(setting: PartitionScaling, axis: usize, max_exp: u32, div_exp: u32) -> u32 {
426    if let PartitionScaling::Disabled = setting {
427        return 2u32.pow(max_exp);
428    }
429
430    let exp = u32::min((axis as u32 / 2u32.pow(div_exp)) + 1, max_exp);
431    2u32.pow(exp)
432}