cubecl_matmul/kernels/layered/algorithm/
simple.rs1use cubecl_core::{Runtime, client::ComputeClient};
2use cubecl_runtime::MmaConfig;
3use std::marker::PhantomData;
4
5use crate::{
6 components::{
7 MatmulElems, MatmulLineSizes, MatmulProblem, MatmulSelection, MatmulSetupError,
8 MultiRowStrategy, TilingScheme,
9 batch::{
10 CubeCountPlanSelection, GlobalOrderSelection, HypercubeSelection,
11 PartitionedBatchMatmulFamily, RowMajorGlobalPartitionMatmul, SmAllocation,
12 },
13 global::{
14 PlaneWriterFamily,
15 read::{SyncFullLoadingStrategy, sync_full_cyclic::SyncFullCyclicLoading},
16 single_stage::simple::SimpleMatmulFamily,
17 },
18 stage::{
19 ColMajorTilingOrder, FilledStageFamily, PartitionBuffering, PlaneMatmulFamily,
20 RowMajorTilingOrder, StridedStageFamily,
21 },
22 tile::{
23 TileMatmulFamily,
24 io::{Filled, Strided},
25 },
26 },
27 kernels::layered::{
28 Algorithm,
29 selector::{PlaneMatmulSelectionOptions, plane_matmul_selection},
30 },
31};
32
33pub struct SimpleAlgorithm<
35 TMM,
36 LL = SyncFullCyclicLoading<ColMajorTilingOrder>,
37 RL = SyncFullCyclicLoading<RowMajorTilingOrder>,
38> {
39 pub _tmm: PhantomData<TMM>,
40 pub _ll: PhantomData<LL>,
41 pub _rl: PhantomData<RL>,
42}
43
/// Selection arguments understood by `SimpleAlgorithm`.
///
/// `Default` produces `multi_rows: false`.
#[derive(Clone, Debug, Default)]
pub struct SimpleArgs {
    /// When `true`, `SimpleAlgorithm::selection` delegates to
    /// `selection_multi_rows`; when `false`, it calls `plane_matmul_selection`
    /// directly.
    pub multi_rows: bool,
}
49
50impl<TMM, LL, RL> Algorithm for SimpleAlgorithm<TMM, LL, RL>
51where
52 TMM:
53 TileMatmulFamily<LhsTile = Strided, RhsTile = Strided, AccTile = Filled, OutTile = Strided>,
54 LL: SyncFullLoadingStrategy,
55 RL: SyncFullLoadingStrategy,
56{
57 type SelectionArgs = SimpleArgs;
58 type TileMatmul = TMM;
59 type StageMatmul = PlaneMatmulFamily<
60 Self::TileMatmul,
61 StridedStageFamily,
62 StridedStageFamily,
63 FilledStageFamily,
64 >;
65 type GlobalMatmul = SimpleMatmulFamily<Self::StageMatmul, LL, RL, PlaneWriterFamily>;
66 type BatchMatmul =
67 PartitionedBatchMatmulFamily<Self::GlobalMatmul, RowMajorGlobalPartitionMatmul>;
68
69 fn selection<R: Runtime>(
70 client: &ComputeClient<R::Server>,
71 problem: &MatmulProblem,
72 plane_dim: u32,
73 _line_sizes: &MatmulLineSizes,
74 elems: MatmulElems,
75 args: &Self::SelectionArgs,
76 ) -> Result<MatmulSelection, MatmulSetupError> {
77 if args.multi_rows {
78 selection_multi_rows::<R, TMM>(client, problem, plane_dim, elems)
79 } else {
80 plane_matmul_selection::<TMM, R>(
81 client,
82 problem,
83 plane_dim,
84 elems,
85 PlaneMatmulSelectionOptions {
86 partition_buffering: Some(PartitionBuffering::Single),
87 tiny_selection_enabled: true,
88 ..Default::default()
89 },
90 )
91 }
92 }
93}
94
95fn selection_multi_rows<R: Runtime, TMM: TileMatmulFamily>(
96 client: &ComputeClient<R::Server>,
97 problem: &MatmulProblem,
98 plane_dim: u32,
99 elems: MatmulElems,
100) -> Result<MatmulSelection, MatmulSetupError> {
101 let supported = |m: u32, n: u32, k: u32| {
102 client.properties().features.cmma.contains(&MmaConfig {
103 a_type: elems.lhs_register,
104 b_type: elems.rhs_register,
105 cd_type: elems.acc_register,
106 m,
107 n,
108 k,
109 })
110 };
111 let cube_count_plan = match client.properties().hardware.num_streaming_multiprocessors {
112 Some(num_sms) => CubeCountPlanSelection::Sm {
113 num_sms,
114 sm_usage: SmAllocation::Exact,
115 cubes_first: true,
116 },
117 None => CubeCountPlanSelection::Flattened,
118 };
119
120 if supported(8, 32, 16) {
121 let tiling_scheme = TilingScheme::builder()
124 .with_tile_size((8, 32, 16).into())
125 .with_partition_size((4, 4, 2).into())
126 .with_stage_size((4, 1, 1).into())
127 .build()
128 .unwrap();
129
130 let hypercube = HypercubeSelection::builder(&tiling_scheme)
131 .global_order(GlobalOrderSelection::SwizzleRow {
132 m: problem.m as u32,
133 w: 4,
134 })
135 .cube_count_plan(cube_count_plan)
136 .build();
137
138 Ok(MatmulSelection::builder(tiling_scheme, plane_dim)
139 .partition_buffering(PartitionBuffering::Single)
140 .hypercube_config(hypercube)
141 .build())
142 } else if supported(8, 8, 8) {
143 let tiling_scheme = TilingScheme::builder()
144 .with_tile_size((8, 8, 8).into())
145 .with_partition_size((4, 8, 2).into())
146 .with_stage_size((4, 1, 1).into())
147 .build()
148 .unwrap();
149 let hypercube = HypercubeSelection::builder(&tiling_scheme)
150 .global_order(GlobalOrderSelection::SwizzleRow {
151 m: problem.m as u32,
152 w: 4,
153 })
154 .cube_count_plan(cube_count_plan)
155 .build();
156
157 Ok(MatmulSelection::builder(tiling_scheme, plane_dim)
158 .partition_buffering(PartitionBuffering::Single)
159 .hypercube_config(hypercube)
160 .build())
161 } else {
162 plane_matmul_selection::<TMM, R>(
163 client,
164 problem,
165 plane_dim,
166 elems,
167 PlaneMatmulSelectionOptions {
168 partition_buffering: Some(PartitionBuffering::Single),
169 multi_row_strategy: MultiRowStrategy::Always(2),
170 partition_k: Some(2),
171 ..Default::default()
172 },
173 )
174 }
175}