cubecl_matmul/kernels/layered/selector/unit.rs

1use cubecl_core::{Runtime, client::ComputeClient};
2
3use crate::components::{
4    MatmulKind, MatmulProblem, MatmulSelection, MatrixLayout, TilingScheme,
5    batch::{CubeCountPlanSelection, GlobalOrderSelection, HypercubeSelection, SmAllocation},
6    stage::PartitionBuffering,
7};
8
/// Strategy used by the unit selectors to pick a tile size.
#[derive(Default, Clone, Copy, Debug)]
pub enum TileSizeSelection {
    /// Chooses the smallest tile size possible.
    MinTileSize,
    /// Chooses the biggest tile size possible. This is the default.
    #[default]
    MaxTileSize,
}
17
/// Whether partition sizes grow with the problem dimensions (see `scale_partition`).
///
/// Defaults to [`PartitionScaling::Enabled`].
#[derive(Clone, Copy, Debug, Default)]
pub enum PartitionScaling {
    /// Partition sizes scale with the problem dimension, up to a per-selector cap.
    #[default]
    Enabled,
    /// Partition sizes are pinned to the per-selector cap.
    Disabled,
}
24
/// Optional shrinking of the stage size picked by a selector.
///
/// Defaults to [`StageScaling::Disabled`].
#[derive(Clone, Copy, Debug, Default)]
pub enum StageScaling {
    /// Divide both stage dimensions (`m` and `n`) by the given factor.
    Enabled(u8),
    /// Keep the selector's stage size unchanged.
    #[default]
    Disabled,
}
31
32#[derive(Default, Clone, Copy, Debug)]
33pub struct UnitMatmulSelectionOptions {
34    pub tile: TileSizeSelection,
35    pub stage: StageScaling,
36    pub partition: PartitionScaling,
37}
38
39/// Computes a [MatmulSelection] depending on the problem kind
40pub fn unit_matmul_selection<R: Runtime>(
41    client: &ComputeClient<R::Server, R::Channel>,
42    problem: &MatmulProblem,
43    plane_dim: u32,
44    double_buffering: bool,
45    options: UnitMatmulSelectionOptions,
46) -> MatmulSelection {
47    let kind: MatmulKind = problem.into();
48    let num_sms = client.properties().hardware.num_streaming_multiprocessors;
49
50    match kind {
51        MatmulKind::General => {
52            general_unit_selector(problem, plane_dim, double_buffering, num_sms, options)
53        }
54        MatmulKind::MatVec => {
55            matvec_unit_selector(problem, plane_dim, double_buffering, num_sms, options)
56        }
57        MatmulKind::VecMat => vecmat_unit_selector(problem, plane_dim, double_buffering, num_sms),
58        MatmulKind::ScalarVec => {
59            scalarvec_unit_selector(problem, plane_dim, double_buffering, num_sms)
60        }
61        MatmulKind::VecScalar => {
62            vecscalar_unit_selector(problem, plane_dim, double_buffering, num_sms)
63        }
64        MatmulKind::InnerProduct => {
65            inner_product_unit_selector(problem, plane_dim, double_buffering, num_sms)
66        }
67        MatmulKind::OuterProduct => {
68            outer_product_unit_selector(problem, plane_dim, double_buffering, num_sms)
69        }
70        MatmulKind::ScalarProduct => {
71            scalar_product_unit_selector(problem, plane_dim, double_buffering, num_sms)
72        }
73    }
74}
75
76/// (M, K) @ (K, N) → (M, N), with M, K, N > 1
77fn general_unit_selector(
78    problem: &MatmulProblem,
79    plane_dim: u32,
80    double_buffering: bool,
81    num_sms: Option<u32>,
82    options: UnitMatmulSelectionOptions,
83) -> MatmulSelection {
84    use MatrixLayout::*;
85
86    // Manually tested for good performance on many shapes.
87    let (tile_size, mut partition_size) =
88        match (problem.lhs_layout, problem.rhs_layout, options.tile) {
89            (RowMajor, _, TileSizeSelection::MinTileSize) => (
90                (1, 4, 4),
91                (
92                    scale_partition(options.partition, problem.m, 4, 9),
93                    2,
94                    scale_partition(options.partition, problem.k, 2, 10),
95                ),
96            ),
97            (ColMajor, RowMajor, TileSizeSelection::MinTileSize) => (
98                (4, 4, 1),
99                (2, 2, scale_partition(options.partition, problem.k, 3, 10)),
100            ),
101            (ColMajor, ColMajor, _) | (_, _, TileSizeSelection::MaxTileSize) => (
102                (4, 4, 4),
103                (
104                    scale_partition(options.partition, problem.m, 2, 9),
105                    2,
106                    scale_partition(options.partition, problem.k, 2, 9),
107                ),
108            ),
109        };
110
111    // It seems to be faster, it's not a requirement of the algo.
112    if double_buffering && partition_size.2 > 2 {
113        partition_size.2 /= 2;
114    }
115
116    selection(
117        tile_size,
118        partition_size,
119        PartitionBuffering::Single,
120        plane_dim,
121        StageSelection::WithPlane {
122            plane_dim,
123            num_plane: 8,
124        },
125        num_sms,
126        GlobalOrderSelection::SwizzleRow {
127            m: problem.m as u32,
128            w: 4,
129        },
130        options.stage,
131    )
132}
133
134/// (M, K) @ (K, 1) → (M, 1)
135fn matvec_unit_selector(
136    problem: &MatmulProblem,
137    plane_dim: u32,
138    _double_buffering: bool,
139    num_sms: Option<u32>,
140    options: UnitMatmulSelectionOptions,
141) -> MatmulSelection {
142    use MatrixLayout::*;
143    let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout, options.tile) {
144        (RowMajor, _, TileSizeSelection::MinTileSize) => ((1, 1, 4), (1, 1, 4)),
145        _ => ((4, 1, 4), (1, 1, 4)),
146    };
147
148    selection(
149        tile_size,
150        partition_size,
151        PartitionBuffering::Single,
152        plane_dim,
153        StageSelection::Fixed { m: 8, n: 8 },
154        num_sms,
155        GlobalOrderSelection::Default,
156        options.stage,
157    )
158}
159
160/// (1, K) @ (K, N) → (1, N)
161fn vecmat_unit_selector(
162    problem: &MatmulProblem,
163    plane_dim: u32,
164    _double_buffering: bool,
165    num_sms: Option<u32>,
166) -> MatmulSelection {
167    use MatrixLayout::*;
168    let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
169        (RowMajor, RowMajor) => ((1, 4, 4), (1, 1, 4)),
170        (RowMajor, ColMajor) => (
171            (1, 4, 4),
172            (2, 1, scale_partition(Default::default(), problem.k, 3, 7)),
173        ),
174        (ColMajor, RowMajor) => ((1, 4, 4), (1, 1, 4)),
175        (ColMajor, ColMajor) => (
176            (1, 4, 4),
177            (
178                2,
179                1,
180                scale_partition(PartitionScaling::Enabled, problem.k, 3, 7),
181            ),
182        ),
183    };
184
185    selection(
186        tile_size,
187        partition_size,
188        PartitionBuffering::Single,
189        plane_dim,
190        StageSelection::Fixed { m: 8, n: 8 },
191        num_sms,
192        GlobalOrderSelection::Default,
193        StageScaling::Disabled,
194    )
195}
196
197/// (1, 1) @ (1, N) → (1, N)
198fn scalarvec_unit_selector(
199    problem: &MatmulProblem,
200    plane_dim: u32,
201    _double_buffering: bool,
202    num_sms: Option<u32>,
203) -> MatmulSelection {
204    use MatrixLayout::*;
205    let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
206        (RowMajor, RowMajor) => ((1, 4, 4), (1, 2, 1)),
207        (RowMajor, ColMajor) => ((1, 4, 4), (1, 2, 1)),
208        (ColMajor, RowMajor) => ((1, 4, 4), (1, 2, 1)),
209        (ColMajor, ColMajor) => ((1, 4, 4), (2, 2, 1)),
210    };
211
212    selection(
213        tile_size,
214        partition_size,
215        PartitionBuffering::Single,
216        plane_dim,
217        StageSelection::Fixed { m: 4, n: 8 },
218        num_sms,
219        GlobalOrderSelection::Default,
220        StageScaling::Disabled,
221    )
222}
223
224/// (M, 1) @ (1, 1) → (M, 1)
225fn vecscalar_unit_selector(
226    _problem: &MatmulProblem,
227    plane_dim: u32,
228    _double_buffering: bool,
229    num_sms: Option<u32>,
230) -> MatmulSelection {
231    let (tile_size, partition_size) = ((4, 1, 4), (1, 1, 2));
232
233    selection(
234        tile_size,
235        partition_size,
236        PartitionBuffering::Single,
237        plane_dim,
238        StageSelection::Fixed { m: 8, n: 4 },
239        num_sms,
240        GlobalOrderSelection::Default,
241        StageScaling::Disabled,
242    )
243}
244
245/// (1, K) @ (K, 1) → (1, 1)
246fn inner_product_unit_selector(
247    problem: &MatmulProblem,
248    plane_dim: u32,
249    _double_buffering: bool,
250    num_sms: Option<u32>,
251) -> MatmulSelection {
252    use MatrixLayout::*;
253    let (tile_size, partition_size) = match (problem.lhs_layout, problem.rhs_layout) {
254        (RowMajor, RowMajor) => ((1, 1, 4), (1, 1, 1)),
255        (RowMajor, ColMajor) => ((1, 1, 4), (1, 1, 1)),
256        (ColMajor, RowMajor) => ((1, 1, 4), (1, 1, 1)),
257        (ColMajor, ColMajor) => ((1, 1, 4), (1, 1, 1)),
258    };
259
260    selection(
261        tile_size,
262        partition_size,
263        PartitionBuffering::Single,
264        plane_dim,
265        StageSelection::Fixed { m: 4, n: 8 },
266        num_sms,
267        GlobalOrderSelection::Default,
268        StageScaling::Disabled,
269    )
270}
271
272/// (M, 1) @ (1, N) → (M, N)
273fn outer_product_unit_selector(
274    _problem: &MatmulProblem,
275    plane_dim: u32,
276    _double_buffering: bool,
277    num_sms: Option<u32>,
278) -> MatmulSelection {
279    let (tile_size, partition_size) = ((4, 4, 1), (1, 1, 1));
280
281    selection(
282        tile_size,
283        partition_size,
284        PartitionBuffering::Single,
285        plane_dim,
286        StageSelection::Fixed { m: 8, n: 8 },
287        num_sms,
288        GlobalOrderSelection::Default,
289        StageScaling::Disabled,
290    )
291}
292
293/// (1, 1) @ (1, 1) → (1, 1)
294fn scalar_product_unit_selector(
295    _problem: &MatmulProblem,
296    plane_dim: u32,
297    _double_buffering: bool,
298    num_sms: Option<u32>,
299) -> MatmulSelection {
300    let (tile_size, partition_size) = ((1, 1, 1), (1, 1, 1));
301
302    selection(
303        tile_size,
304        partition_size,
305        PartitionBuffering::Single,
306        plane_dim,
307        StageSelection::WithPlane {
308            plane_dim,
309            num_plane: 8,
310        },
311        num_sms,
312        GlobalOrderSelection::Default,
313        StageScaling::Disabled,
314    )
315}
316
/// How a selector specifies the stage's `(m, n)` size.
enum StageSelection {
    /// Derive the size from `num_plane * plane_dim` units (see `closest_factor_pair`).
    WithPlane {
        plane_dim: u32,
        num_plane: u32,
    },
    /// Use exactly `m` by `n`.
    Fixed {
        m: u32,
        n: u32,
    },
}
321
322impl StageSelection {
323    fn into_stages(self) -> (u32, u32) {
324        match self {
325            StageSelection::WithPlane {
326                plane_dim: plane_size,
327                num_plane: num_planes,
328            } => {
329                let num_units = num_planes * plane_size;
330                closest_factor_pair(num_units)
331            }
332            StageSelection::Fixed { m, n } => (m, n),
333        }
334    }
335}
336
337#[allow(clippy::too_many_arguments)]
338fn selection(
339    t: (u32, u32, u32),
340    p: (u32, u32, u32),
341    buffering: PartitionBuffering,
342    plane_dim: u32,
343    stage: StageSelection,
344    num_sms: Option<u32>,
345    global_order_config: GlobalOrderSelection,
346    stage_scaling: StageScaling,
347) -> MatmulSelection {
348    let (stage_size_m, stage_size_n) = stage.into_stages();
349
350    let (stage_size_m, stage_size_n) = match stage_scaling {
351        StageScaling::Enabled(f) => (stage_size_m / f as u32, stage_size_n / f as u32),
352        StageScaling::Disabled => (stage_size_m, stage_size_n),
353    };
354
355    let tiling_scheme = TilingScheme::builder()
356        .with_tile_size(t.into())
357        .with_partition_size(p.into())
358        .with_stage_size((stage_size_m, stage_size_n, 1).into())
359        .build()
360        .unwrap();
361
362    let cube_count_plan = match num_sms {
363        Some(num_sms) => CubeCountPlanSelection::Sm {
364            num_sms,
365            sm_usage: SmAllocation::Exact,
366            cubes_first: false,
367        },
368        None => CubeCountPlanSelection::Flattened,
369    };
370
371    let hypercube = HypercubeSelection::builder(&tiling_scheme)
372        .global_order(global_order_config)
373        .cube_count_plan(cube_count_plan)
374        .build();
375
376    MatmulSelection::builder(tiling_scheme, plane_dim)
377        .partition_buffering(buffering)
378        .hypercube_config(hypercube)
379        .build()
380}
381
/// Splits `n` into the factor pair `(a, b)` with `a * b == n`, `a >= b` and the
/// smallest possible difference `a - b` (the "most square" factorization).
///
/// `n == 0` has no such pair; `(0, 1)` is returned in that case.
pub fn closest_factor_pair(n: u32) -> (u32, u32) {
    // The smaller member of any factor pair is at most sqrt(n), so scanning down
    // from there yields the most balanced pair first.
    let limit = (n as f64).sqrt() as u32;
    (1..=limit)
        .rev()
        .find(|&b| n % b == 0)
        .map_or((n, 1), |b| (n / b, b))
}
393
394fn scale_partition(setting: PartitionScaling, axis: usize, max_exp: u32, div_exp: u32) -> u32 {
395    if let PartitionScaling::Disabled = setting {
396        return 2u32.pow(max_exp);
397    }
398
399    let exp = u32::min((axis as u32 / 2u32.pow(div_exp)) + 1, max_exp);
400    2u32.pow(exp)
401}