use cubecl_common::quant::scheme::{QuantScheme, QuantStore, QuantValue};
use cubecl_core::{
    Runtime,
    client::ComputeClient,
    ir::StorageType,
    prelude::{CubePrimitive, TensorHandleRef},
};

use cubecl_std::tensor::{TensorHandle, into_contiguous_packed, into_contiguous_pitched};
use serde::{Deserialize, Serialize};

use crate::{
    components::{
        MatmulElems, MatmulSetupError,
        tile::{cmma::CmmaMatmul, io::Filled, mma::MmaMatmul},
    },
    kernels::layered::{
        Selection,
        double_buffering::{DoubleBufferingArgs, TmaDoubleBufferingAlgorithm},
        double_unit::{DoubleUnitAlgorithm, DoubleUnitSelectionArgs},
        ordered_double_buffering::OrderedSelectionArgs,
        simple::SimpleArgs,
        simple_unit::SimpleUnitSelectionArgs,
        specialized::TmaSpecializedAlgorithm,
        vecmat::{DoubleVecMatAlgorithm, SimpleVecMatAlgorithm},
    },
};

use super::{
    components::{
        global::read::{
            async_full_cooperative, async_full_cyclic, async_full_maximize_slice_length,
            async_full_maximize_unit_count, sync_full_strided, sync_full_tilewise,
        },
        stage::{ColMajorTilingOrder, RowMajorTilingOrder},
    },
    kernels::{
        layered::{
            self,
            double_buffering::{
                CyclicDoubleBufferingAlgorithm, HybridDoubleBufferingAlgorithm,
                TilewiseDoubleBufferingAlgorithm,
            },
            ordered_double_buffering::OrderedDoubleBufferingAlgorithm,
            simple::{SimpleAlgorithm, SimpleTmaAlgorithm},
            simple_unit::SimpleUnitAlgorithm,
        },
        naive,
    },
};

/// Strategy used to select which matmul kernel to launch.
#[derive(Debug, Clone, Default)]
pub enum Strategy {
    /// Accelerated matmul with full-stage loading configured by [`ReadingStrategy`].
    Simple {
        read_strategy: ReadingStrategy,
        selection: Selection<SimpleArgs>,
        tile_kind: AcceleratedTileKind,
    },
    /// Double-buffered accelerated matmul with partial-stage loading configured by
    /// [`PartialReadingStrategy`].
    DoubleBuffering {
        read_strategy: PartialReadingStrategy,
        selection: Selection<DoubleBufferingArgs>,
        tile_kind: AcceleratedTileKind,
    },
    /// TMA-based specialized accelerated matmul.
    Specialized {
        selection: Selection<()>,
        tile_kind: AcceleratedTileKind,
    },
    /// Simple matmul running on plain units, without tile acceleration.
    SimpleUnit(Selection<SimpleUnitSelectionArgs>),
    /// Double-buffered matmul running on plain units, without tile acceleration.
    DoubleUnit(Selection<DoubleUnitSelectionArgs>),
    /// Simple vector-matrix product.
    SimpleVecMat(Selection<()>),
    /// Double-buffered vector-matrix product.
    DoubleVecMat(Selection<()>),
    /// Ordered double-buffered accelerated matmul.
    OrderedDoubleBuffering {
        selection: Selection<OrderedSelectionArgs>,
        tile_kind: AcceleratedTileKind,
    },
    /// Naive (non-layered) matmul kernel.
    Naive,
    /// Try the accelerated simple matmul first, falling back to the simple unit matmul
    /// if acceleration is unavailable.
    #[default]
    Auto,
}

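// Illustrative sketch of picking an explicit strategy rather than `Strategy::Auto`
// (the variant and field values below are examples, not recommendations):
//
// let strategy = Strategy::Simple {
//     read_strategy: ReadingStrategy::Cyclic,
//     selection: Default::default(),
//     tile_kind: AcceleratedTileKind::Cmma,
// };
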
/// Reading strategy for full-stage loading, used by [`Strategy::Simple`].
#[derive(Debug, Clone, Copy)]
pub enum ReadingStrategy {
    Cyclic,
    Strided,
    Tilewise,
    AsyncCooperative,
    AsyncCyclic,
    AsyncMaximizeSliceLength,
    AsyncMaximizeUnitCount,
    Tma,
}

/// Reading strategy for partial-stage loading, used by [`Strategy::DoubleBuffering`].
#[derive(Debug, Clone, Copy)]
pub enum PartialReadingStrategy {
    Cyclic,
    Tilewise,
    Hybrid,
    Tma,
}

/// Tile-level matmul implementation used by the accelerated strategies.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum AcceleratedTileKind {
    /// Tiles computed with [`CmmaMatmul`].
    #[default]
    Cmma,
    /// Tiles computed with [`MmaMatmul`].
    Mma,
}

/// Resolves an [`AcceleratedTileKind`] at runtime, binding the matching tile matmul
/// type to the alias `$T` before calling `$launch`.
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                type $T = CmmaMatmul<Filled>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                type $T = MmaMatmul;
                ($launch)()
            }
        }
    };
}

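// Usage sketch, mirroring the launch paths below: the macro selects the tile matmul
// type for the requested kind and exposes it as `Accelerated` inside the closure.
//
// with_tile_kind!(tile_kind, Accelerated, || {
//     layered::launch_ref::<R, SimpleAlgorithm<Accelerated>>(
//         client, lhs, rhs, out, selection, dtypes,
//     )
// })
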
/// Owned matmul operand handle: either a plain tensor or quantized data with its scales.
pub enum MatmulInputHandle<R: Runtime> {
    Normal(TensorHandle<R>),
    Quantized {
        /// Quantized values.
        data: TensorHandle<R>,
        /// Quantization scales.
        scale: TensorHandle<R>,
        /// Logical (unpacked) shape of the tensor.
        shape: Vec<usize>,
        /// Quantization scheme describing how `data` is stored.
        scheme: QuantScheme,
    },
}

impl<R: Runtime> MatmulInputHandle<R> {
    pub fn as_ref(&self) -> MatmulInputHandleRef<'_, R> {
        match self {
            MatmulInputHandle::Normal(handle) => {
                MatmulInputHandleRef::Normal(handle.as_ref(), handle.dtype)
            }
            MatmulInputHandle::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => MatmulInputHandleRef::Quantized {
                data: data.as_ref(),
                scale: scale.as_ref(),
                data_dtype: data.dtype,
                scale_dtype: scale.dtype,
                shape,
                scheme,
            },
        }
    }

    pub fn from_ref(handle: &MatmulInputHandleRef<'_, R>) -> Self {
        match handle {
            MatmulInputHandleRef::Normal(handle, dtype) => {
                MatmulInputHandle::Normal(TensorHandle::from_ref(handle, *dtype))
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
                data_dtype,
                scale_dtype,
            } => MatmulInputHandle::Quantized {
                data: TensorHandle::from_ref(data, *data_dtype),
                scale: TensorHandle::from_ref(scale, *scale_dtype),
                shape: shape.to_vec(),
                scheme: **scheme,
            },
        }
    }

    pub fn data(&self) -> &TensorHandle<R> {
        match self {
            MatmulInputHandle::Normal(handle) => handle,
            MatmulInputHandle::Quantized { data, .. } => data,
        }
    }

    pub fn swap_dims(&mut self, dim0: usize, dim1: usize) {
        match self {
            MatmulInputHandle::Normal(handle) => {
                handle.shape.swap(dim0, dim1);
                handle.strides.swap(dim0, dim1);
            }
            MatmulInputHandle::Quantized {
                data, scale, shape, ..
            } => {
                data.shape.swap(dim0, dim1);
                data.strides.swap(dim0, dim1);
                if scale.shape.len() == data.shape.len() {
                    scale.shape.swap(dim0, dim1);
                    scale.strides.swap(dim0, dim1);
                }
                shape.swap(dim0, dim1);
            }
        }
    }
}

impl<R: Runtime> Clone for MatmulInputHandle<R> {
    fn clone(&self) -> Self {
        match self {
            Self::Normal(handle) => Self::Normal(handle.clone()),
            Self::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => Self::Quantized {
                data: data.clone(),
                scale: scale.clone(),
                shape: shape.clone(),
                scheme: *scheme,
            },
        }
    }
}

/// Borrowed counterpart of [`MatmulInputHandle`].
#[derive(Debug)]
pub enum MatmulInputHandleRef<'a, R: Runtime> {
    Normal(TensorHandleRef<'a, R>, StorageType),
    Quantized {
        data: TensorHandleRef<'a, R>,
        data_dtype: StorageType,
        scale: TensorHandleRef<'a, R>,
        scale_dtype: StorageType,
        shape: &'a [usize],
        scheme: &'a QuantScheme,
    },
}

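// Sketch: wrapping a plain (non-quantized) operand for the launch entry points,
// assuming `lhs` is an existing `TensorHandle<R>`:
//
// let lhs_input = MatmulInputHandleRef::new(lhs.as_ref(), lhs.dtype);
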
impl<'a, R: Runtime> Clone for MatmulInputHandleRef<'a, R> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, R: Runtime> Copy for MatmulInputHandleRef<'a, R> {}

impl<'a, R: Runtime> MatmulInputHandleRef<'a, R> {
    pub fn new(data: TensorHandleRef<'a, R>, dtype: StorageType) -> Self {
        Self::Normal(data, dtype)
    }

    pub fn quantized(
        data: TensorHandleRef<'a, R>,
        scale: TensorHandleRef<'a, R>,
        shape: &'a [usize],
        scheme: &'a QuantScheme,
        data_dtype: StorageType,
        scale_dtype: StorageType,
    ) -> Self {
        Self::Quantized {
            data,
            scale,
            shape,
            scheme,
            data_dtype,
            scale_dtype,
        }
    }

    pub fn data(&self) -> &TensorHandleRef<'a, R> {
        match self {
            MatmulInputHandleRef::Normal(handle, ..) => handle,
            MatmulInputHandleRef::Quantized { data, .. } => data,
        }
    }

    pub fn data_mut(&mut self) -> &mut TensorHandleRef<'a, R> {
        match self {
            MatmulInputHandleRef::Normal(handle, ..) => handle,
            MatmulInputHandleRef::Quantized { data, .. } => data,
        }
    }

    pub fn scale(&self) -> Option<&TensorHandleRef<'a, R>> {
        match self {
            MatmulInputHandleRef::Normal(..) => None,
            MatmulInputHandleRef::Quantized { scale, .. } => Some(scale),
        }
    }

    pub fn scheme(&self) -> Option<&QuantScheme> {
        match self {
            MatmulInputHandleRef::Normal(..) => None,
            MatmulInputHandleRef::Quantized { scheme, .. } => Some(scheme),
        }
    }

    pub fn shape(&self) -> &[usize] {
        match self {
            MatmulInputHandleRef::Normal(handle, ..) => handle.shape,
            MatmulInputHandleRef::Quantized { shape, .. } => shape,
        }
    }

    /// Copy the data into a contiguous layout: packed for packed quantized stores,
    /// pitched otherwise.
    pub fn into_contiguous(&self, client: &ComputeClient<R::Server>) -> MatmulInputHandle<R> {
        match self {
            MatmulInputHandleRef::Normal(data, dtype) => {
                MatmulInputHandle::Normal(into_contiguous_pitched::<R>(client, data, *dtype))
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
                data_dtype,
                scale_dtype,
            } => {
                let data = match scheme.store {
                    // Natively stored E2M1 values are packed two per byte.
                    QuantStore::Native if scheme.value == QuantValue::E2M1 => {
                        let data = into_contiguous_packed::<R>(
                            client,
                            data,
                            shape,
                            2,
                            u8::as_type_native_unchecked(),
                        );
                        TensorHandle::from_ref(&data.as_ref(), *data_dtype)
                    }
                    // U32 storage packs `num_quants` values into each 32-bit word.
                    QuantStore::U32 => {
                        let data = into_contiguous_packed::<R>(
                            client,
                            data,
                            shape,
                            scheme.num_quants() as u32,
                            u32::as_type_native_unchecked(),
                        );
                        TensorHandle::from_ref(&data.as_ref(), *data_dtype)
                    }
                    _ => into_contiguous_pitched::<R>(client, data, *data_dtype),
                };
                MatmulInputHandle::Quantized {
                    data,
                    scale: TensorHandle::from_ref(scale, *scale_dtype),
                    shape: shape.to_vec(),
                    scheme: **scheme,
                }
            }
        }
    }
}

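// Launch sketch (illustrative): with a compute client, two input handles and an
// output handle already prepared, the default `Strategy::Auto` can be launched as:
//
// launch::<R>(
//     &Strategy::default(),
//     &client,
//     lhs_handle,
//     rhs_handle,
//     out_handle,
//     dtypes,
// )?;
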
/// Launch a matrix multiplication kernel with the given [`Strategy`], taking owned
/// input and output handles.
#[allow(clippy::result_large_err)]
pub fn launch<R: Runtime>(
    strategy: &Strategy,
    client: &ComputeClient<R::Server>,
    lhs: MatmulInputHandle<R>,
    rhs: MatmulInputHandle<R>,
    out: TensorHandle<R>,
    mut dtypes: MatmulElems,
) -> Result<(), MatmulSetupError> {
    launch_ref::<R>(
        strategy,
        client,
        &lhs.as_ref(),
        &rhs.as_ref(),
        &out.as_ref(),
        &mut dtypes,
    )
}

/// Launch a matrix multiplication kernel with the given [`Strategy`], taking borrowed
/// input and output handles.
#[allow(clippy::result_large_err)]
pub fn launch_ref<R: Runtime>(
    strategy: &Strategy,
    client: &ComputeClient<R::Server>,
    lhs: &MatmulInputHandleRef<R>,
    rhs: &MatmulInputHandleRef<R>,
    out: &TensorHandleRef<R>,
    dtypes: &mut MatmulElems,
) -> Result<(), MatmulSetupError> {
    match strategy {
        Strategy::Simple {
            read_strategy,
            selection,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
            ReadingStrategy::Cyclic => {
                layered::launch_ref::<R, SimpleAlgorithm<Accelerated>>(
                    client, lhs, rhs, out, selection, dtypes,
                )
            }
            ReadingStrategy::Strided => layered::launch_ref::<
                R,
                SimpleAlgorithm<
                    Accelerated,
                    sync_full_strided::SyncFullStridedLoading,
                    sync_full_strided::SyncFullStridedLoading,
                >,
            >(client, lhs, rhs, out, selection, dtypes),
            ReadingStrategy::Tilewise => {
                layered::launch_ref::<
                    R,
                    SimpleAlgorithm<
                        Accelerated,
                        sync_full_tilewise::SyncFullTilewiseLoading<ColMajorTilingOrder>,
                        sync_full_tilewise::SyncFullTilewiseLoading<RowMajorTilingOrder>,
                    >,
                >(client, lhs, rhs, out, selection, dtypes)
            }
            ReadingStrategy::AsyncCooperative => {
                layered::launch_ref::<
                    R,
                    SimpleAlgorithm<
                        Accelerated,
                        async_full_cooperative::AsyncFullCooperativeLoading,
                        async_full_cooperative::AsyncFullCooperativeLoading,
                    >,
                >(client, lhs, rhs, out, selection, dtypes)
            }
            ReadingStrategy::AsyncCyclic => {
                layered::launch_ref::<
                    R,
                    SimpleAlgorithm<
                        Accelerated,
                        async_full_cyclic::AsyncFullCyclicLoading<ColMajorTilingOrder>,
                        async_full_cyclic::AsyncFullCyclicLoading<RowMajorTilingOrder>,
                    >,
                >(client, lhs, rhs, out, selection, dtypes)
            }
            ReadingStrategy::AsyncMaximizeSliceLength => {
                layered::launch_ref::<
                    R,
                    SimpleAlgorithm<
                        Accelerated,
                        async_full_maximize_slice_length::AsyncFullMaximizeSliceLengthLoading,
                        async_full_maximize_slice_length::AsyncFullMaximizeSliceLengthLoading,
                    >,
                >(client, lhs, rhs, out, &Default::default(), dtypes)
            }
            ReadingStrategy::AsyncMaximizeUnitCount => {
                layered::launch_ref::<
                    R,
                    SimpleAlgorithm<
                        Accelerated,
                        async_full_maximize_unit_count::AsyncFullMaximizeUnitCountLoading,
                        async_full_maximize_unit_count::AsyncFullMaximizeUnitCountLoading,
                    >,
                >(client, lhs, rhs, out, &Default::default(), dtypes)
            }
            ReadingStrategy::Tma => layered::launch_ref_tma::<R, SimpleTmaAlgorithm<Accelerated>>(
                client, lhs, rhs, out, selection, dtypes
            ),
        }),
        Strategy::DoubleBuffering {
            read_strategy,
            selection,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
            PartialReadingStrategy::Cyclic => {
                layered::launch_ref::<R, CyclicDoubleBufferingAlgorithm<Accelerated>>(
                    client, lhs, rhs, out, selection, dtypes,
                )
            }
            PartialReadingStrategy::Tilewise => {
                layered::launch_ref::<R, TilewiseDoubleBufferingAlgorithm<Accelerated>>(
                    client, lhs, rhs, out, selection, dtypes,
                )
            }
            PartialReadingStrategy::Hybrid => {
                layered::launch_ref::<R, HybridDoubleBufferingAlgorithm<Accelerated>>(
                    client, lhs, rhs, out, selection, dtypes,
                )
            }
            PartialReadingStrategy::Tma => {
                layered::launch_ref_tma::<R, TmaDoubleBufferingAlgorithm<Accelerated>>(
                    client, lhs, rhs, out, selection, dtypes,
                )
            }
        }),
        Strategy::Specialized {
            selection,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || layered::launch_ref_tma::<
            R,
            TmaSpecializedAlgorithm<Accelerated>,
        >(
            client, lhs, rhs, out, selection, dtypes
        )),
        Strategy::OrderedDoubleBuffering {
            selection,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || layered::launch_ref::<
            R,
            OrderedDoubleBufferingAlgorithm<Accelerated>,
        >(
            client, lhs, rhs, out, selection, dtypes
        )),
        Strategy::SimpleUnit(selection) => {
            layered::launch_ref::<R, SimpleUnitAlgorithm>(client, lhs, rhs, out, selection, dtypes)
        }
        Strategy::DoubleUnit(selection) => {
            layered::launch_ref::<R, DoubleUnitAlgorithm>(client, lhs, rhs, out, selection, dtypes)
        }
        Strategy::Naive => {
            naive::launch_ref::<R>(client, lhs, rhs, out, dtypes)?;
            Ok(())
        }
        Strategy::Auto => {
            // Try the accelerated simple matmul first and fall back to the unit
            // matmul when acceleration is unavailable on the current hardware.
            if let Err(err) = layered::launch_ref::<R, SimpleAlgorithm<CmmaMatmul<Filled>>>(
                client,
                lhs,
                rhs,
                out,
                &Default::default(),
                dtypes,
            ) {
                match err {
                    MatmulSetupError::Unavailable(_) => {
                        layered::launch_ref::<R, SimpleUnitAlgorithm>(
                            client,
                            lhs,
                            rhs,
                            out,
                            &Default::default(),
                            dtypes,
                        )
                        .unwrap();
                    }
                    _ => panic!("{err:?}"),
                }
            }

            Ok(())
        }
        Strategy::SimpleVecMat(selection) => layered::launch_ref::<R, SimpleVecMatAlgorithm>(
            client, lhs, rhs, out, selection, dtypes,
        ),
        Strategy::DoubleVecMat(selection) => layered::launch_ref::<R, DoubleVecMatAlgorithm>(
            client, lhs, rhs, out, selection, dtypes,
        ),
    }
}