1use std::fmt::Display;
2
3use cubecl_common::quant::scheme::{QuantScheme, QuantStore, QuantValue};
4use cubecl_core::{
5 Runtime,
6 client::ComputeClient,
7 ir::StorageType,
8 prelude::{CubePrimitive, TensorHandleRef},
9 server::LaunchError,
10};
11
12use cubecl_std::tensor::{TensorHandle, into_contiguous_packed, into_contiguous_pitched};
13use serde::{Deserialize, Serialize};
14
15use crate::{
16 components::{
17 MatmulElems, MatmulSetupError,
18 global::read::{
19 async_partial_cyclic::AsyncPartialCyclicLoading,
20 async_partial_strided::AsyncPartialStridedLoading,
21 },
22 tile::{cmma::CmmaMatmul, io::Filled, mma::MmaMatmul},
23 },
24 kernels::layered::{
25 Selection,
26 double_buffering::*,
27 double_unit::{DoubleUnitAlgorithm, DoubleUnitSelectionArgs},
28 ordered_double_buffering::OrderedSelectionArgs,
29 simple::SimpleArgs,
30 simple_unit::SimpleUnitSelectionArgs,
31 specialized::SpecializedAlgorithm,
32 vecmat::{DoubleVecMatAlgorithm, SimpleVecMatAlgorithm},
33 },
34};
35
36use super::{
37 components::{
38 global::read::{
39 async_full_cooperative, async_full_cyclic, sync_full_strided, sync_full_tilewise,
40 },
41 stage::{ColMajorTilingOrder, RowMajorTilingOrder},
42 },
43 kernels::{
44 layered::{
45 self,
46 double_buffering::{
47 CyclicDoubleBufferingAlgorithm, HybridDoubleBufferingAlgorithm,
48 TilewiseDoubleBufferingAlgorithm,
49 },
50 ordered_double_buffering::OrderedDoubleBufferingAlgorithm,
51 simple::{SimpleAlgorithm, SimpleTmaAlgorithm},
52 simple_unit::SimpleUnitAlgorithm,
53 },
54 naive,
55 },
56};
57
58#[derive(Debug, Clone, Default)]
59pub enum Strategy {
64 Simple {
65 read_strategy: ReadingStrategy,
66 selection: Selection<SimpleArgs>,
67 tile_kind: AcceleratedTileKind,
68 },
69 DoubleBuffering {
70 read_strategy: PartialReadingStrategy,
71 selection: Selection<DoubleBufferingArgs>,
72 tile_kind: AcceleratedTileKind,
73 },
74 Specialized {
75 read_strategy: AsyncPartialReadingStrategy,
76 selection: Selection<()>,
77 tile_kind: AcceleratedTileKind,
78 },
79 SimpleUnit(Selection<SimpleUnitSelectionArgs>),
80 DoubleUnit(Selection<DoubleUnitSelectionArgs>),
81 SimpleVecMat(Selection<()>),
82 DoubleVecMat(Selection<()>),
83 OrderedDoubleBuffering {
84 selection: Selection<OrderedSelectionArgs>,
85 tile_kind: AcceleratedTileKind,
86 },
87 Naive,
88 #[default]
89 Auto,
91}
92
93impl Display for Strategy {
94 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95 match self {
96 Strategy::Simple {
97 read_strategy,
98 selection,
99 tile_kind,
100 } => {
101 f.write_fmt(format_args!("matmul_simple_{read_strategy}_{tile_kind}"))?;
102
103 match selection {
104 Selection::Forced(_) => f.write_str("_forced_selection")?,
105 Selection::Inferred(args) => {
106 if args.multi_rows {
107 f.write_str("_multirows")?;
108 }
109 }
110 };
111 }
112 Strategy::DoubleBuffering {
113 read_strategy,
114 selection,
115 tile_kind,
116 } => {
117 f.write_fmt(format_args!(
118 "matmul_double_buffering_{read_strategy}_{tile_kind}"
119 ))?;
120
121 match selection {
122 Selection::Forced(_) => f.write_str("_forced_selection")?,
123 Selection::Inferred(args) => {
124 if args.specialized {
125 f.write_str("_specialized")?;
126 }
127 }
128 };
129 }
130 Strategy::Specialized {
131 read_strategy,
132 selection,
133 tile_kind,
134 } => {
135 f.write_fmt(format_args!(
136 "matmul_specialized_{read_strategy}_{tile_kind}"
137 ))?;
138
139 match selection {
140 Selection::Forced(_) => f.write_str("_forced_selection")?,
141 Selection::Inferred(_) => {}
142 };
143 }
144 Strategy::SimpleUnit(selection) => {
145 f.write_fmt(format_args!("matmul_simple_unit"))?;
146
147 match selection {
148 Selection::Forced(_) => f.write_str("_forced_selection")?,
149 Selection::Inferred(args) => {
150 f.write_fmt(format_args!("_{}", args.tile_size))?;
151 }
152 };
153 }
154 Strategy::DoubleUnit(selection) => {
155 f.write_str("matmul_double_buffering_unit")?;
156
157 match selection {
158 Selection::Forced(_) => f.write_str("_forced_selection")?,
159 Selection::Inferred(args) => {
160 f.write_fmt(format_args!("_{}", args.tile_size))?;
161 }
162 };
163 }
164 Strategy::SimpleVecMat(selection) => {
165 f.write_str("vecmat_simple")?;
166
167 match selection {
168 Selection::Forced(_) => f.write_str("_forced_selection")?,
169 Selection::Inferred(_) => {}
170 };
171 }
172 Strategy::DoubleVecMat(selection) => {
173 f.write_str("vecmat_double_buffering")?;
174
175 match selection {
176 Selection::Forced(_) => f.write_str("_forced_selection")?,
177 Selection::Inferred(_) => {}
178 };
179 }
180 Strategy::OrderedDoubleBuffering {
181 selection,
182 tile_kind,
183 } => {
184 f.write_fmt(format_args!("matmul_double_buffering_ordered_{tile_kind}"))?;
185
186 match selection {
187 Selection::Forced(_) => f.write_str("_forced_selection")?,
188 Selection::Inferred(args) => {
189 if let Some(k) = args.partition_k {
190 f.write_fmt(format_args!("_partition_k{}", k))?;
191 }
192 if let Some(r) = args.row_count {
193 f.write_fmt(format_args!("_row_count{}", r))?;
194 }
195 if let Some(r) = args.rows_per_plane {
196 f.write_fmt(format_args!("_row_per_plane{}", r))?;
197 }
198 }
199 };
200 }
201 Strategy::Naive => f.write_str("matmul_naive")?,
202 Strategy::Auto => f.write_str("matmul_auto")?,
203 };
204
205 Ok(())
206 }
207}
208
/// Loader choice for [`Strategy::Simple`].
///
/// Derives `PartialEq`/`Eq` so values can be compared directly, matching
/// [`AcceleratedTileKind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadingStrategy {
    /// `SimpleAlgorithm` with its default loaders.
    Cyclic,
    /// `SimpleAlgorithm` with `SyncFullStridedLoading` for both loaders.
    Strided,
    /// `SimpleAlgorithm` with `SyncFullTilewiseLoading`, column-major tiling order for
    /// the first loader and row-major for the second.
    Tilewise,
    /// `SimpleAlgorithm` with `AsyncFullCooperativeLoading` for both loaders.
    AsyncCooperative,
    /// `SimpleAlgorithm` with `AsyncFullCyclicLoading`, column-major tiling order for
    /// the first loader and row-major for the second.
    AsyncCyclic,
    /// `SimpleTmaAlgorithm`, launched through `layered::launch_ref_tma`.
    Tma,
}
219
/// Algorithm choice for [`Strategy::DoubleBuffering`].
///
/// Derives `PartialEq`/`Eq` so values can be compared directly, matching
/// [`AcceleratedTileKind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartialReadingStrategy {
    /// Launches `CyclicDoubleBufferingAlgorithm`.
    Cyclic,
    /// Launches `TilewiseDoubleBufferingAlgorithm`.
    Tilewise,
    /// Launches `HybridDoubleBufferingAlgorithm`.
    Hybrid,
    /// Launches `TmaDoubleBufferingAlgorithm` through `layered::launch_ref_tma`.
    Tma,
    /// Launches `AsyncCyclicDoubleBufferingAlgorithm`.
    AsyncCyclic,
    /// Launches `AsyncStridedDoubleBufferingAlgorithm`.
    AsyncStrided,
}
230
/// Loader choice for [`Strategy::Specialized`].
///
/// Derives `PartialEq`/`Eq` so values can be compared directly, matching
/// [`AcceleratedTileKind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsyncPartialReadingStrategy {
    /// `SpecializedAlgorithm` with `AsyncPartialCyclicLoading<ColMajorTilingOrder>`.
    Cyclic,
    /// `SpecializedAlgorithm` with `AsyncPartialStridedLoading`.
    Strided,
    /// `SpecializedAlgorithm` with its default loader, launched through
    /// `layered::launch_ref_tma`.
    Tma,
}
238
239#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
240pub enum AcceleratedTileKind {
242 #[default]
243 Cmma,
244 Mma,
245}
246
247impl Display for AcceleratedTileKind {
250 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251 match self {
252 AcceleratedTileKind::Cmma => f.write_str("cmma"),
253 AcceleratedTileKind::Mma => f.write_str("mma"),
254 }
255 }
256}
257
258impl Display for ReadingStrategy {
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 match self {
261 ReadingStrategy::Cyclic => f.write_str("cyclic"),
262 ReadingStrategy::Strided => f.write_str("strided"),
263 ReadingStrategy::Tilewise => f.write_str("tilewise"),
264 ReadingStrategy::AsyncCooperative => f.write_str("async_cooperative"),
265 ReadingStrategy::AsyncCyclic => f.write_str("async_cyclic"),
266 ReadingStrategy::Tma => f.write_str("tma"),
267 }
268 }
269}
270
271impl Display for PartialReadingStrategy {
272 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273 match self {
274 PartialReadingStrategy::Cyclic => f.write_str("cyclic"),
275 PartialReadingStrategy::Tilewise => f.write_str("tilewise"),
276 PartialReadingStrategy::Hybrid => f.write_str("hybrid"),
277 PartialReadingStrategy::Tma => f.write_str("tma"),
278 PartialReadingStrategy::AsyncCyclic => f.write_str("async_cyclic"),
279 PartialReadingStrategy::AsyncStrided => f.write_str("async_strided"),
280 }
281 }
282}
283
284impl Display for AsyncPartialReadingStrategy {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 match self {
287 AsyncPartialReadingStrategy::Cyclic => f.write_str("cyclic"),
288 AsyncPartialReadingStrategy::Strided => f.write_str("strided"),
289 AsyncPartialReadingStrategy::Tma => f.write_str("tma"),
290 }
291 }
292}
293
/// Runs a launch closure with a type alias bound to the tile matmul type that
/// corresponds to an [`AcceleratedTileKind`].
///
/// * `$tile_kind` - expression evaluating to the tile kind to match on.
/// * `$alias` - identifier the caller uses inside the closure to name the tile type.
/// * `$body` - zero-argument closure; it is expanded once per arm so that `$alias`
///   resolves to a different concrete type in each.
macro_rules! with_tile_kind {
    ($tile_kind:expr, $alias:ident, $body:expr) => {
        match $tile_kind {
            AcceleratedTileKind::Cmma => {
                type $alias = CmmaMatmul<Filled>;
                ($body)()
            }
            AcceleratedTileKind::Mma => {
                type $alias = MmaMatmul;
                ($body)()
            }
        }
    };
}
308
309pub enum MatmulInputHandle<R: Runtime> {
310 Normal(TensorHandle<R>),
311 Quantized {
312 data: TensorHandle<R>,
313 scale: TensorHandle<R>,
314 shape: Vec<usize>,
315 scheme: QuantScheme,
316 },
317}
318
319impl<R: Runtime> MatmulInputHandle<R> {
320 pub fn as_ref(&self) -> MatmulInputHandleRef<'_, R> {
321 match self {
322 MatmulInputHandle::Normal(handle) => {
323 MatmulInputHandleRef::Normal(handle.as_ref(), handle.dtype)
324 }
325 MatmulInputHandle::Quantized {
326 data,
327 scale,
328 shape,
329 scheme,
330 } => MatmulInputHandleRef::Quantized {
331 data: data.as_ref(),
332 scale: scale.as_ref(),
333 data_dtype: data.dtype,
334 scale_dtype: scale.dtype,
335 shape,
336 scheme,
337 },
338 }
339 }
340
341 pub fn from_ref(handle: &MatmulInputHandleRef<'_, R>) -> Self {
342 match handle {
343 MatmulInputHandleRef::Normal(handle, dtype) => {
344 MatmulInputHandle::Normal(TensorHandle::from_ref(handle, *dtype))
345 }
346 MatmulInputHandleRef::Quantized {
347 data,
348 scale,
349 shape,
350 scheme,
351 data_dtype,
352 scale_dtype,
353 } => MatmulInputHandle::Quantized {
354 data: TensorHandle::from_ref(data, *data_dtype),
355 scale: TensorHandle::from_ref(scale, *scale_dtype),
356 shape: shape.to_vec(),
357 scheme: **scheme,
358 },
359 }
360 }
361
362 pub fn data(&self) -> &TensorHandle<R> {
363 match self {
364 MatmulInputHandle::Normal(handle) => handle,
365 MatmulInputHandle::Quantized { data, .. } => data,
366 }
367 }
368
369 pub fn swap_dims(&mut self, dim0: usize, dim1: usize) {
370 match self {
371 MatmulInputHandle::Normal(handle) => {
372 handle.shape.swap(dim0, dim1);
373 handle.strides.swap(dim0, dim1);
374 }
375 MatmulInputHandle::Quantized {
376 data, scale, shape, ..
377 } => {
378 data.shape.swap(dim0, dim1);
379 data.strides.swap(dim0, dim1);
380 if scale.shape.len() == data.shape.len() {
381 scale.shape.swap(dim0, dim1);
382 scale.strides.swap(dim0, dim1);
383 }
384 shape.swap(dim0, dim1);
385 }
386 }
387 }
388}
389
390impl<R: Runtime> Clone for MatmulInputHandle<R> {
391 fn clone(&self) -> Self {
392 match self {
393 Self::Normal(handle) => Self::Normal(handle.clone()),
394 Self::Quantized {
395 data,
396 scale,
397 shape,
398 scheme,
399 } => Self::Quantized {
400 data: data.clone(),
401 scale: scale.clone(),
402 shape: shape.clone(),
403 scheme: *scheme,
404 },
405 }
406 }
407}
408
409#[derive(Debug)]
410pub enum MatmulInputHandleRef<'a, R: Runtime> {
411 Normal(TensorHandleRef<'a, R>, StorageType),
412 Quantized {
413 data: TensorHandleRef<'a, R>,
414 data_dtype: StorageType,
415 scale: TensorHandleRef<'a, R>,
416 scale_dtype: StorageType,
417 shape: &'a [usize],
419 scheme: &'a QuantScheme,
420 },
421}
422
423impl<'a, R: Runtime> Clone for MatmulInputHandleRef<'a, R> {
424 fn clone(&self) -> Self {
425 *self
426 }
427}
428
429impl<'a, R: Runtime> Copy for MatmulInputHandleRef<'a, R> {}
430
431impl<'a, R: Runtime> MatmulInputHandleRef<'a, R> {
432 pub fn new(data: TensorHandleRef<'a, R>, dtype: StorageType) -> Self {
433 Self::Normal(data, dtype)
434 }
435
436 pub fn quantized(
437 data: TensorHandleRef<'a, R>,
438 scale: TensorHandleRef<'a, R>,
439 shape: &'a [usize],
440 scheme: &'a QuantScheme,
441 data_dtype: StorageType,
442 scale_dtype: StorageType,
443 ) -> Self {
444 Self::Quantized {
445 data,
446 scale,
447 shape,
448 scheme,
449 data_dtype,
450 scale_dtype,
451 }
452 }
453
454 pub fn data(&self) -> &TensorHandleRef<'a, R> {
455 match self {
456 MatmulInputHandleRef::Normal(handle, ..) => handle,
457 MatmulInputHandleRef::Quantized { data, .. } => data,
458 }
459 }
460
461 pub fn data_mut(&mut self) -> &mut TensorHandleRef<'a, R> {
462 match self {
463 MatmulInputHandleRef::Normal(handle, ..) => handle,
464 MatmulInputHandleRef::Quantized { data, .. } => data,
465 }
466 }
467
468 pub fn scale(&self) -> Option<&TensorHandleRef<'a, R>> {
469 match self {
470 MatmulInputHandleRef::Normal(..) => None,
471 MatmulInputHandleRef::Quantized { scale, .. } => Some(scale),
472 }
473 }
474
475 pub fn scheme(&self) -> Option<&QuantScheme> {
476 match self {
477 MatmulInputHandleRef::Normal(..) => None,
478 MatmulInputHandleRef::Quantized { scheme, .. } => Some(scheme),
479 }
480 }
481
482 pub fn shape(&self) -> &[usize] {
483 match self {
484 MatmulInputHandleRef::Normal(handle, ..) => handle.shape,
485 MatmulInputHandleRef::Quantized { shape, .. } => shape,
486 }
487 }
488
489 pub fn into_contiguous(
490 &self,
491 client: &ComputeClient<R>,
492 ) -> Result<MatmulInputHandle<R>, LaunchError> {
493 let val = match self {
494 MatmulInputHandleRef::Normal(data, dtype) => {
495 MatmulInputHandle::Normal(into_contiguous_pitched(client, data, *dtype)?)
496 }
497 MatmulInputHandleRef::Quantized {
498 data,
499 scale,
500 shape,
501 scheme,
502 data_dtype,
503 scale_dtype,
504 } => {
505 let data = match scheme.store {
506 QuantStore::Native if scheme.value == QuantValue::E2M1 => {
508 let data = into_contiguous_packed(
509 client,
510 data,
511 shape,
512 2,
513 u8::as_type_native_unchecked(),
514 )?;
515 TensorHandle::from_ref(&data.as_ref(), *data_dtype)
517 }
518 QuantStore::U32 => {
519 let data = into_contiguous_packed(
520 client,
521 data,
522 shape,
523 scheme.num_quants() as u32,
524 u32::as_type_native_unchecked(),
525 )?;
526 TensorHandle::from_ref(&data.as_ref(), *data_dtype)
528 }
529 _ => into_contiguous_pitched(client, data, *data_dtype)?,
530 };
531 MatmulInputHandle::Quantized {
532 data,
533 scale: TensorHandle::from_ref(scale, *scale_dtype),
534 shape: shape.to_vec(),
535 scheme: **scheme,
536 }
537 }
538 };
539
540 Ok(val)
541 }
542}
543
544#[allow(clippy::result_large_err)]
545pub fn launch<R: Runtime>(
546 strategy: &Strategy,
547 client: &ComputeClient<R>,
548 lhs: MatmulInputHandle<R>,
549 rhs: MatmulInputHandle<R>,
550 out: TensorHandle<R>,
551 mut dtypes: MatmulElems,
552) -> Result<(), MatmulSetupError> {
553 launch_ref(
554 strategy,
555 client,
556 &lhs.as_ref(),
557 &rhs.as_ref(),
558 &out.as_ref(),
559 &mut dtypes,
560 )
561}
562
563#[allow(clippy::result_large_err)]
564pub fn launch_ref<R: Runtime>(
572 strategy: &Strategy,
573 client: &ComputeClient<R>,
574 lhs: &MatmulInputHandleRef<R>,
575 rhs: &MatmulInputHandleRef<R>,
576 out: &TensorHandleRef<R>,
577 dtypes: &mut MatmulElems,
578) -> Result<(), MatmulSetupError> {
579 match strategy {
580 Strategy::Simple {
581 read_strategy,
582 selection,
583 tile_kind,
584 } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
585 ReadingStrategy::Cyclic => {
586 layered::launch_ref::<R, SimpleAlgorithm<Accelerated>>(
587 client, lhs, rhs, out, selection, dtypes,
588 )
589 }
590 ReadingStrategy::Strided => layered::launch_ref::<
591 R,
592 SimpleAlgorithm<
593 Accelerated,
594 sync_full_strided::SyncFullStridedLoading,
595 sync_full_strided::SyncFullStridedLoading,
596 >,
597 >(client, lhs, rhs, out, selection, dtypes),
598 ReadingStrategy::Tilewise => {
599 layered::launch_ref::<
600 R,
601 SimpleAlgorithm<
602 Accelerated,
603 sync_full_tilewise::SyncFullTilewiseLoading<ColMajorTilingOrder>,
604 sync_full_tilewise::SyncFullTilewiseLoading<RowMajorTilingOrder>,
605 >,
606 >(client, lhs, rhs, out, selection, dtypes)
607 }
608 ReadingStrategy::AsyncCooperative => {
609 layered::launch_ref::<
610 R,
611 SimpleAlgorithm<
612 Accelerated,
613 async_full_cooperative::AsyncFullCooperativeLoading,
614 async_full_cooperative::AsyncFullCooperativeLoading,
615 >,
616 >(client, lhs, rhs, out, selection, dtypes)
617 }
618 ReadingStrategy::AsyncCyclic => {
619 layered::launch_ref::<
620 R,
621 SimpleAlgorithm<
622 Accelerated,
623 async_full_cyclic::AsyncFullCyclicLoading<ColMajorTilingOrder>,
624 async_full_cyclic::AsyncFullCyclicLoading<RowMajorTilingOrder>,
625 >,
626 >(client, lhs, rhs, out, selection, dtypes)
627 }
628 ReadingStrategy::Tma => layered::launch_ref_tma::<R, SimpleTmaAlgorithm<Accelerated>>(
629 client, lhs, rhs, out, selection, dtypes
630 ),
631 }),
632 Strategy::DoubleBuffering {
633 read_strategy,
634 selection,
635 tile_kind,
636 } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
637 PartialReadingStrategy::Cyclic => {
638 layered::launch_ref::<R, CyclicDoubleBufferingAlgorithm<Accelerated>>(
639 client, lhs, rhs, out, selection, dtypes,
640 )
641 }
642 PartialReadingStrategy::Tilewise => {
643 layered::launch_ref::<R, TilewiseDoubleBufferingAlgorithm<Accelerated>>(
644 client, lhs, rhs, out, selection, dtypes,
645 )
646 }
647 PartialReadingStrategy::Hybrid => {
648 layered::launch_ref::<R, HybridDoubleBufferingAlgorithm<Accelerated>>(
649 client, lhs, rhs, out, selection, dtypes,
650 )
651 }
652 PartialReadingStrategy::Tma => {
653 layered::launch_ref_tma::<R, TmaDoubleBufferingAlgorithm<Accelerated>>(
654 client, lhs, rhs, out, selection, dtypes,
655 )
656 }
657 PartialReadingStrategy::AsyncCyclic => {
658 layered::launch_ref::<R, AsyncCyclicDoubleBufferingAlgorithm<Accelerated>>(
659 client, lhs, rhs, out, selection, dtypes,
660 )
661 }
662 PartialReadingStrategy::AsyncStrided => {
663 layered::launch_ref::<R, AsyncStridedDoubleBufferingAlgorithm<Accelerated>>(
664 client, lhs, rhs, out, selection, dtypes,
665 )
666 }
667 }),
668 Strategy::Specialized {
669 read_strategy,
670 selection,
671 tile_kind,
672 } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
673 AsyncPartialReadingStrategy::Cyclic => layered::launch_ref::<
674 R,
675 SpecializedAlgorithm<Accelerated, AsyncPartialCyclicLoading<ColMajorTilingOrder>>,
676 >(
677 client, lhs, rhs, out, selection, dtypes
678 ),
679 AsyncPartialReadingStrategy::Strided =>
680 layered::launch_ref::<
681 R,
682 SpecializedAlgorithm<Accelerated, AsyncPartialStridedLoading>,
683 >(client, lhs, rhs, out, selection, dtypes),
684 AsyncPartialReadingStrategy::Tma =>
685 layered::launch_ref_tma::<R, SpecializedAlgorithm<Accelerated>>(
686 client, lhs, rhs, out, selection, dtypes
687 ),
688 }),
689 Strategy::OrderedDoubleBuffering {
690 selection,
691 tile_kind,
692 } => with_tile_kind!(tile_kind, Accelerated, || layered::launch_ref::<
693 R,
694 OrderedDoubleBufferingAlgorithm<Accelerated>,
695 >(
696 client, lhs, rhs, out, selection, dtypes
697 )),
698 Strategy::SimpleUnit(selection) => {
699 layered::launch_ref::<R, SimpleUnitAlgorithm>(client, lhs, rhs, out, selection, dtypes)
700 }
701 Strategy::DoubleUnit(selection) => {
702 layered::launch_ref::<R, DoubleUnitAlgorithm>(client, lhs, rhs, out, selection, dtypes)
703 }
704 Strategy::Naive => {
705 naive::launch_ref(client, lhs, rhs, out, dtypes)?;
706 Ok(())
707 }
708 Strategy::Auto => {
709 if let Err(err) = layered::launch_ref::<R, SimpleAlgorithm<CmmaMatmul<Filled>>>(
710 client,
711 lhs,
712 rhs,
713 out,
714 &Default::default(),
715 dtypes,
716 ) {
717 match err {
718 MatmulSetupError::Unavailable(_) => {
719 layered::launch_ref::<R, SimpleUnitAlgorithm>(
720 client,
721 lhs,
722 rhs,
723 out,
724 &Default::default(),
725 dtypes,
726 )
727 .unwrap();
728 }
729 _ => panic!("{err:?}"),
730 }
731 }
732
733 Ok(())
734 }
735 Strategy::SimpleVecMat(selection) => layered::launch_ref::<R, SimpleVecMatAlgorithm>(
736 client, lhs, rhs, out, selection, dtypes,
737 ),
738 Strategy::DoubleVecMat(selection) => layered::launch_ref::<R, DoubleVecMatAlgorithm>(
739 client, lhs, rhs, out, selection, dtypes,
740 ),
741 }
742}