cubecl_matmul/kernels/layered/algorithm/
simple.rs

1use cubecl_core::{Runtime, client::ComputeClient};
2use cubecl_runtime::MmaConfig;
3use std::marker::PhantomData;
4
5use crate::{
6    components::{
7        MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
8        MultiRowStrategy, TilingScheme,
9        batch::{
10            CubeCountPlanSelection, GlobalOrderSelection, HypercubeSelection,
11            PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul, SmAllocation,
12        },
13        global::{
14            PlaneWriterFamily,
15            read::{SyncFullLoadingStrategy, sync_full_cyclic::SyncFullCyclicLoading},
16            single_stage::simple::SimpleMatmulFamily,
17        },
18        stage::{
19            ColMajorTilingOrder, FilledStageFamily, PartitionBuffering, PlaneMatmulFamily,
20            RowMajorTilingOrder, StridedStageFamily,
21        },
22        tile::{
23            TileMatmulFamily,
24            io::{Filled, Strided},
25        },
26    },
27    kernels::layered::{
28        Algorithm,
29        selector::{PlaneMatmulSelectionOptions, plane_matmul_selection},
30    },
31};
32
33/// Plane accelerated single stage matmul with configurable readers (default to cyclic)
34pub struct SimpleAlgorithm<
35    TMM,
36    LL = SyncFullCyclicLoading<ColMajorTilingOrder>,
37    RL = SyncFullCyclicLoading<RowMajorTilingOrder>,
38> {
39    pub _tmm: PhantomData<TMM>,
40    pub _ll: PhantomData<LL>,
41    pub _rl: PhantomData<RL>,
42}
43
/// Selection arguments accepted by the simple algorithm.
///
/// The default value leaves `multi_rows` disabled.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct SimpleArgs {
    /// Uses an optimized multi rows strategy.
    pub multi_rows: bool,
}
49
50impl<TMM, LL, RL> Algorithm for SimpleAlgorithm<TMM, LL, RL>
51where
52    TMM:
53        TileMatmulFamily<LhsTile = Strided, RhsTile = Strided, AccTile = Filled, OutTile = Strided>,
54    LL: SyncFullLoadingStrategy,
55    RL: SyncFullLoadingStrategy,
56{
57    type SelectionArgs = SimpleArgs;
58    type TileMatmul = TMM;
59    type StageMatmul = PlaneMatmulFamily<
60        Self::TileMatmul,
61        StridedStageFamily,
62        StridedStageFamily,
63        FilledStageFamily,
64    >;
65    type GlobalMatmul = SimpleMatmulFamily<Self::StageMatmul, LL, RL, PlaneWriterFamily>;
66    type BatchMatmul =
67        PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
68
69    fn selection<R: Runtime>(
70        client: &ComputeClient<R::Server>,
71        problem: &MatmulProblem,
72        plane_dim: u32,
73        _line_sizes: &MatmulLineSizes,
74        elems: MatmulElems,
75        args: &Self::SelectionArgs,
76    ) -> Result<MatmulSelection, MatmulSetupError> {
77        if args.multi_rows {
78            selection_multi_rows::<R, TMM>(client, problem, plane_dim, elems)
79        } else {
80            plane_matmul_selection::<TMM, R>(
81                client,
82                problem,
83                plane_dim,
84                elems,
85                PlaneMatmulSelectionOptions {
86                    partition_buffering: Some(PartitionBuffering::Single),
87                    tiny_selection_enabled: true,
88                    ..Default::default()
89                },
90            )
91        }
92    }
93}
94
95fn selection_multi_rows<R: Runtime, TMM: TileMatmulFamily>(
96    client: &ComputeClient<R::Server>,
97    problem: &MatmulProblem,
98    plane_dim: u32,
99    elems: MatmulElems,
100) -> Result<MatmulSelection, MatmulSetupError> {
101    let supported = |m: u32, n: u32, k: u32| {
102        client.properties().features.cmma.contains(&MmaConfig {
103            a_type: elems.lhs_register,
104            b_type: elems.rhs_register,
105            cd_type: elems.acc_register,
106            m,
107            n,
108            k,
109        })
110    };
111    let cube_count_plan = match client.properties().hardware.num_streaming_multiprocessors {
112        Some(num_sms) => CubeCountPlanSelection::Sm {
113            num_sms,
114            sm_usage: SmAllocation::Exact,
115            cubes_first: true,
116        },
117        None => CubeCountPlanSelection::Flattened,
118    };
119
120    if supported(8, 32, 16) {
121        // A lot of multi-rows balanced with a
122        // tile size of (8, 32, 16)
123        let tiling_scheme = TilingScheme::builder()
124            .with_tile_size((8, 32, 16).into())
125            .with_partition_size((4, 4, 2).into())
126            .with_stage_size((4, 1, 1).into())
127            .build()
128            .unwrap();
129
130        let hypercube = HypercubeSelection::builder(&tiling_scheme)
131            .global_order(GlobalOrderSelection::SwizzleRow {
132                m: problem.m as u32,
133                w: 4,
134            })
135            .cube_count_plan(cube_count_plan)
136            .build();
137
138        Ok(MatmulSelection::builder(tiling_scheme, plane_dim)
139            .partition_buffering(PartitionBuffering::Single)
140            .hypercube_config(hypercube)
141            .build())
142    } else if supported(8, 8, 8) {
143        let tiling_scheme = TilingScheme::builder()
144            .with_tile_size((8, 8, 8).into())
145            .with_partition_size((4, 8, 2).into())
146            .with_stage_size((4, 1, 1).into())
147            .build()
148            .unwrap();
149        let hypercube = HypercubeSelection::builder(&tiling_scheme)
150            .global_order(GlobalOrderSelection::SwizzleRow {
151                m: problem.m as u32,
152                w: 4,
153            })
154            .cube_count_plan(cube_count_plan)
155            .build();
156
157        Ok(MatmulSelection::builder(tiling_scheme, plane_dim)
158            .partition_buffering(PartitionBuffering::Single)
159            .hypercube_config(hypercube)
160            .build())
161    } else {
162        plane_matmul_selection::<TMM, R>(
163            client,
164            problem,
165            plane_dim,
166            elems,
167            PlaneMatmulSelectionOptions {
168                partition_buffering: Some(PartitionBuffering::Single),
169                multi_row_strategy: MultiRowStrategy::Always(2),
170                partition_k: Some(2),
171                ..Default::default()
172            },
173        )
174    }
175}