1use cubecl_core::{Runtime, client::ComputeClient, prelude::TensorHandleRef};
2
3use cubecl_std::tensor::TensorHandle;
4
5use crate::{
6 components::{MatmulSetupError, tile::accelerated::AcceleratedMatmul},
7 kernels::layered::{
8 Selection,
9 double_buffering::DoubleBufferingArgs,
10 double_unit::{DoubleUnitAlgorithm, DoubleUnitSelectionArgs},
11 ordered_double_buffering::OrderedSelectionArgs,
12 simple::SimpleArgs,
13 simple_unit::SimpleUnitSelectionArgs,
14 },
15};
16
17use super::{
18 components::{
19 MatmulPrecision,
20 global::load::{
21 async_full_cooperative, async_full_cyclic, async_full_maximize_slice_length,
22 async_full_maximize_unit_count, sync_full_strided, sync_full_tilewise,
23 },
24 stage::{ColMajorTilingOrder, RowMajorTilingOrder},
25 },
26 kernels::{
27 layered::{
28 self,
29 double_buffering::{
30 CyclicDoubleBufferingAlgorithm, HybridDoubleBufferingAlgorithm,
31 TilewiseDoubleBufferingAlgorithm,
32 },
33 ordered_double_buffering::OrderedDoubleBufferingAlgorithm,
34 simple::SimpleAlgorithm,
35 simple_barrier::SimpleBarrierAlgorithm,
36 simple_tma::SimpleTmaAlgorithm,
37 simple_unit::SimpleUnitAlgorithm,
38 },
39 naive,
40 },
41};
42
43#[derive(Debug, Clone, Default)]
44pub enum Strategy {
49 Simple(SyncLoadingStrategy, Selection<SimpleArgs>),
50 SimpleBarrier(AsyncLoadingStrategy),
51 DoubleBuffering(SyncPartialLoadingStrategy, Selection<DoubleBufferingArgs>),
52 SimpleUnit(Selection<SimpleUnitSelectionArgs>),
53 DoubleUnit(Selection<DoubleUnitSelectionArgs>),
54 OrderedDoubleBuffering(Selection<OrderedSelectionArgs>),
55 Naive,
56 #[default]
57 Auto,
59}
60
/// Loader choice for [`Strategy::Simple`].
#[derive(Clone, Debug)]
pub enum SyncLoadingStrategy {
    /// Uses `SimpleAlgorithm`'s default loader type parameters.
    Cyclic,
    /// Uses `sync_full_strided::SyncFullStridedLoading` for both loader
    /// type parameters.
    Strided,
    /// Uses `sync_full_tilewise::SyncFullTilewiseLoading`, with
    /// `ColMajorTilingOrder` for the first loader parameter and
    /// `RowMajorTilingOrder` for the second.
    Tilewise,
}
68
/// Algorithm choice for [`Strategy::DoubleBuffering`].
#[derive(Clone, Debug)]
pub enum SyncPartialLoadingStrategy {
    /// Selects `CyclicDoubleBufferingAlgorithm`.
    Cyclic,
    /// Selects `TilewiseDoubleBufferingAlgorithm`.
    Tilewise,
    /// Selects `HybridDoubleBufferingAlgorithm`.
    Hybrid,
}
76
/// Loader choice for [`Strategy::SimpleBarrier`].
#[derive(Clone, Debug)]
pub enum AsyncLoadingStrategy {
    /// Uses `async_full_cooperative::AsyncFullCooperativeLoading`.
    Cooperative,
    /// Uses `async_full_cyclic::AsyncFullCyclicLoading<ColMajorTilingOrder>`.
    Cyclic,
    /// Uses `async_full_maximize_slice_length::AsyncFullMaximizeSliceLengthLoading`.
    MaximizeSliceLength,
    /// Uses `async_full_maximize_unit_count::AsyncFullMaximizeUnitCountLoading`.
    MaximizeUnitCount,
    /// Uses `SimpleTmaAlgorithm`, launched through the dedicated TMA entry point.
    Tma,
}
86
87#[allow(clippy::result_large_err)]
88pub fn launch<R: Runtime, MP: MatmulPrecision>(
89 strategy: &Strategy,
90 client: &ComputeClient<R::Server, R::Channel>,
91 lhs: TensorHandle<R, MP::EI>,
92 lhs_scale: Option<TensorHandle<R, f32>>,
93 rhs: TensorHandle<R, MP::EI>,
94 rhs_scale: Option<TensorHandle<R, f32>>,
95 out: TensorHandle<R, MP::EO>,
96) -> Result<(), MatmulSetupError> {
97 launch_ref::<R, MP>(
98 strategy,
99 client,
100 &lhs.as_ref(),
101 &lhs_scale.as_ref().map(|it| it.as_ref()),
102 &rhs.as_ref(),
103 &rhs_scale.as_ref().map(|it| it.as_ref()),
104 &out.as_ref(),
105 )
106}
107
108#[allow(clippy::result_large_err)]
109pub fn launch_ref<R: Runtime, MP: MatmulPrecision>(
110 strategy: &Strategy,
111 client: &ComputeClient<R::Server, R::Channel>,
112 lhs: &TensorHandleRef<R>,
113 lhs_scale: &Option<TensorHandleRef<R>>,
114 rhs: &TensorHandleRef<R>,
115 rhs_scale: &Option<TensorHandleRef<R>>,
116 out: &TensorHandleRef<R>,
117) -> Result<(), MatmulSetupError> {
118 match strategy {
119 Strategy::Simple(loading_strategy, selection) => match loading_strategy {
120 SyncLoadingStrategy::Cyclic => {
121 layered::launch_ref::<R, MP, SimpleAlgorithm<AcceleratedMatmul>>(
122 client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
123 )
124 }
125 SyncLoadingStrategy::Strided => {
126 layered::launch_ref::<
127 R,
128 MP,
129 SimpleAlgorithm<
130 AcceleratedMatmul,
131 sync_full_strided::SyncFullStridedLoading,
132 sync_full_strided::SyncFullStridedLoading,
133 >,
134 >(client, lhs, lhs_scale, rhs, rhs_scale, out, selection)
135 }
136 SyncLoadingStrategy::Tilewise => layered::launch_ref::<
137 R,
138 MP,
139 SimpleAlgorithm<
140 AcceleratedMatmul,
141 sync_full_tilewise::SyncFullTilewiseLoading<ColMajorTilingOrder>,
142 sync_full_tilewise::SyncFullTilewiseLoading<RowMajorTilingOrder>,
143 >,
144 >(
145 client,
146 lhs,
147 lhs_scale,
148 rhs,
149 rhs_scale,
150 out,
151 &Default::default(),
152 ),
153 },
154 Strategy::SimpleBarrier(loading_strategy) => match loading_strategy {
155 AsyncLoadingStrategy::Cooperative => layered::launch_ref::<
156 R,
157 MP,
158 SimpleBarrierAlgorithm<
159 AcceleratedMatmul,
160 async_full_cooperative::AsyncFullCooperativeLoading,
161 >,
162 >(
163 client,
164 lhs,
165 lhs_scale,
166 rhs,
167 rhs_scale,
168 out,
169 &Default::default(),
170 ),
171 AsyncLoadingStrategy::Cyclic => layered::launch_ref::<
172 R,
173 MP,
174 SimpleBarrierAlgorithm<
175 AcceleratedMatmul,
176 async_full_cyclic::AsyncFullCyclicLoading<ColMajorTilingOrder>,
177 >,
178 >(
179 client,
180 lhs,
181 lhs_scale,
182 rhs,
183 rhs_scale,
184 out,
185 &Default::default(),
186 ),
187 AsyncLoadingStrategy::MaximizeSliceLength => layered::launch_ref::<
188 R,
189 MP,
190 SimpleBarrierAlgorithm<
191 AcceleratedMatmul,
192 async_full_maximize_slice_length::AsyncFullMaximizeSliceLengthLoading,
193 >,
194 >(
195 client,
196 lhs,
197 lhs_scale,
198 rhs,
199 rhs_scale,
200 out,
201 &Default::default(),
202 ),
203 AsyncLoadingStrategy::MaximizeUnitCount => layered::launch_ref::<
204 R,
205 MP,
206 SimpleBarrierAlgorithm<
207 AcceleratedMatmul,
208 async_full_maximize_unit_count::AsyncFullMaximizeUnitCountLoading,
209 >,
210 >(
211 client,
212 lhs,
213 lhs_scale,
214 rhs,
215 rhs_scale,
216 out,
217 &Default::default(),
218 ),
219 AsyncLoadingStrategy::Tma => {
220 layered::matmul_cmma_tma_ref_no_check::<R, MP, SimpleTmaAlgorithm<AcceleratedMatmul>>(
221 client,
222 lhs,
223 lhs_scale,
224 rhs,
225 rhs_scale,
226 out,
227 (false, false),
228 &Default::default(),
229 )
230 }
231 },
232 Strategy::DoubleBuffering(loading_strategy, selection) => match loading_strategy {
233 SyncPartialLoadingStrategy::Cyclic => {
234 layered::launch_ref::<R, MP, CyclicDoubleBufferingAlgorithm<AcceleratedMatmul>>(
235 client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
236 )
237 }
238 SyncPartialLoadingStrategy::Tilewise => {
239 layered::launch_ref::<R, MP, TilewiseDoubleBufferingAlgorithm<AcceleratedMatmul>>(
240 client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
241 )
242 }
243 SyncPartialLoadingStrategy::Hybrid => {
244 layered::launch_ref::<R, MP, HybridDoubleBufferingAlgorithm<AcceleratedMatmul>>(
245 client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
246 )
247 }
248 },
249 Strategy::OrderedDoubleBuffering(selection) => {
250 layered::launch_ref::<R, MP, OrderedDoubleBufferingAlgorithm<AcceleratedMatmul>>(
251 client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
252 )
253 }
254 Strategy::SimpleUnit(selection) => layered::launch_ref::<R, MP, SimpleUnitAlgorithm>(
255 client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
256 ),
257 Strategy::DoubleUnit(selection) => layered::launch_ref::<R, MP, DoubleUnitAlgorithm>(
258 client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
259 ),
260 Strategy::Naive => {
261 naive::launch_ref::<R, MP::EI>(client, lhs, rhs, out)?;
263 Ok(())
264 }
265 Strategy::Auto => {
266 if let Err(err) = layered::launch_ref::<R, MP, SimpleAlgorithm<AcceleratedMatmul>>(
267 client,
268 lhs,
269 lhs_scale,
270 rhs,
271 rhs_scale,
272 out,
273 &Default::default(),
274 ) {
275 match err {
276 MatmulSetupError::Unavailable(_) => {
277 layered::launch_ref::<R, MP, SimpleUnitAlgorithm>(
278 client,
279 lhs,
280 lhs_scale,
281 rhs,
282 rhs_scale,
283 out,
284 &Default::default(),
285 )
286 .unwrap();
287 }
288 _ => panic!("{err:?}"),
289 }
290 }
291
292 Ok(())
293 }
294 }
295}