// cubecl_matmul/base.rs

1use cubecl_common::quant::scheme::{QuantScheme, QuantStore, QuantValue};
2use cubecl_core::{
3    Runtime,
4    client::ComputeClient,
5    ir::StorageType,
6    prelude::{CubePrimitive, TensorHandleRef},
7};
8
9use cubecl_std::tensor::{TensorHandle, into_contiguous_packed, into_contiguous_pitched};
10use serde::{Deserialize, Serialize};
11
12use crate::{
13    components::{
14        MatmulElems, MatmulSetupError,
15        tile::{cmma::CmmaMatmul, io::Filled, mma::MmaMatmul},
16    },
17    kernels::layered::{
18        Selection,
19        double_buffering::{DoubleBufferingArgs, TmaDoubleBufferingAlgorithm},
20        double_unit::{DoubleUnitAlgorithm, DoubleUnitSelectionArgs},
21        ordered_double_buffering::OrderedSelectionArgs,
22        simple::SimpleArgs,
23        simple_unit::SimpleUnitSelectionArgs,
24        specialized::TmaSpecializedAlgorithm,
25        vecmat::{DoubleVecMatAlgorithm, SimpleVecMatAlgorithm},
26    },
27};
28
29use super::{
30    components::{
31        global::read::{
32            async_full_cooperative, async_full_cyclic, async_full_maximize_slice_length,
33            async_full_maximize_unit_count, sync_full_strided, sync_full_tilewise,
34        },
35        stage::{ColMajorTilingOrder, RowMajorTilingOrder},
36    },
37    kernels::{
38        layered::{
39            self,
40            double_buffering::{
41                CyclicDoubleBufferingAlgorithm, HybridDoubleBufferingAlgorithm,
42                TilewiseDoubleBufferingAlgorithm,
43            },
44            ordered_double_buffering::OrderedDoubleBufferingAlgorithm,
45            simple::{SimpleAlgorithm, SimpleTmaAlgorithm},
46            simple_unit::SimpleUnitAlgorithm,
47        },
48        naive,
49    },
50};
51
#[derive(Debug, Clone, Default)]
/// The matmul algorithm to launch
///
/// Most strategies have a selection input that can be overwritten or inferred from minimal information
/// Some strategies must have a specified loading strategy
pub enum Strategy {
    /// Single-stage accelerated matmul; `read_strategy` picks the global-memory loader
    /// used by `SimpleAlgorithm` (or the TMA variant).
    Simple {
        read_strategy: ReadingStrategy,
        selection: Selection<SimpleArgs>,
        tile_kind: AcceleratedTileKind,
    },
    /// Double-buffered accelerated matmul; `read_strategy` selects the concrete
    /// double-buffering algorithm (cyclic / tilewise / hybrid / TMA).
    DoubleBuffering {
        read_strategy: PartialReadingStrategy,
        selection: Selection<DoubleBufferingArgs>,
        tile_kind: AcceleratedTileKind,
    },
    /// Specialized algorithm launched through the TMA entry point
    /// (`TmaSpecializedAlgorithm`).
    Specialized {
        selection: Selection<()>,
        tile_kind: AcceleratedTileKind,
    },
    /// Launched as `SimpleUnitAlgorithm` (no accelerated tile kind involved).
    SimpleUnit(Selection<SimpleUnitSelectionArgs>),
    /// Launched as `DoubleUnitAlgorithm`.
    DoubleUnit(Selection<DoubleUnitSelectionArgs>),
    /// Launched as `SimpleVecMatAlgorithm` (vector-matrix product).
    SimpleVecMat(Selection<()>),
    /// Launched as `DoubleVecMatAlgorithm` (vector-matrix product).
    DoubleVecMat(Selection<()>),
    /// Double buffering with ordered loading (`OrderedDoubleBufferingAlgorithm`).
    OrderedDoubleBuffering {
        selection: Selection<OrderedSelectionArgs>,
        tile_kind: AcceleratedTileKind,
    },
    /// Launched via `naive::launch_ref`.
    Naive,
    #[default]
    /// Tries using a Simple matmul, then a SimpleUnit if the former failed
    Auto,
}
85
#[derive(Debug, Clone, Copy)]
/// Which reader to use in simple algorithms
pub enum ReadingStrategy {
    /// Synchronous cyclic loading (uses `SimpleAlgorithm`'s default loaders).
    Cyclic,
    /// Synchronous strided loading (`sync_full_strided`) on both inputs.
    Strided,
    /// Synchronous tilewise loading; col-major tiling order on lhs, row-major on rhs.
    Tilewise,
    /// Asynchronous cooperative loading (`async_full_cooperative`) on both inputs.
    AsyncCooperative,
    /// Asynchronous cyclic loading; col-major tiling order on lhs, row-major on rhs.
    AsyncCyclic,
    /// Asynchronous loading maximizing slice length (`async_full_maximize_slice_length`).
    AsyncMaximizeSliceLength,
    /// Asynchronous loading maximizing unit count (`async_full_maximize_unit_count`).
    AsyncMaximizeUnitCount,
    /// Loading through the TMA entry point (`SimpleTmaAlgorithm`).
    Tma,
}
98
#[derive(Debug, Clone, Copy)]
/// Which reader to use in double buffering algorithms
pub enum PartialReadingStrategy {
    /// `CyclicDoubleBufferingAlgorithm`.
    Cyclic,
    /// `TilewiseDoubleBufferingAlgorithm`.
    Tilewise,
    /// `HybridDoubleBufferingAlgorithm`.
    Hybrid,
    /// `TmaDoubleBufferingAlgorithm`, launched through the TMA entry point.
    Tma,
}
107
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
/// Which tile matmul to use for accelerated algorithms
pub enum AcceleratedTileKind {
    /// Cooperative matrix multiply-accumulate tile matmul (`CmmaMatmul<Filled>`).
    #[default]
    Cmma,
    /// Manual MMA tile matmul (`MmaMatmul`).
    Mma,
}
115
/// Expands `$launch` once per accelerated tile kind, binding the type alias `$T`
/// to the concrete tile matmul (`CmmaMatmul<Filled>` or `MmaMatmul`) matching the
/// runtime value of `$kind`, then invokes the `$launch` closure.
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                // The local `type` alias lets the closure body name `$T` as a
                // generic argument even though the kind is only known at runtime.
                type $T = CmmaMatmul<Filled>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                type $T = MmaMatmul;
                ($launch)()
            }
        }
    };
}
130
/// Owned input tensor for a matmul, either plain or quantized.
pub enum MatmulInputHandle<R: Runtime> {
    /// Regular (non-quantized) tensor.
    Normal(TensorHandle<R>),
    /// Quantized tensor together with its scales.
    Quantized {
        // Quantized (possibly packed) values.
        data: TensorHandle<R>,
        // Quantization scales.
        scale: TensorHandle<R>,
        // Unpacked logical shape; may differ from `data`'s shape when the
        // values are packed (see `MatmulInputHandleRef::into_contiguous`).
        shape: Vec<usize>,
        // Quantization scheme describing value and storage types.
        scheme: QuantScheme,
    },
}
140
141impl<R: Runtime> MatmulInputHandle<R> {
142    pub fn as_ref(&self) -> MatmulInputHandleRef<'_, R> {
143        match self {
144            MatmulInputHandle::Normal(handle) => {
145                MatmulInputHandleRef::Normal(handle.as_ref(), handle.dtype)
146            }
147            MatmulInputHandle::Quantized {
148                data,
149                scale,
150                shape,
151                scheme,
152            } => MatmulInputHandleRef::Quantized {
153                data: data.as_ref(),
154                scale: scale.as_ref(),
155                data_dtype: data.dtype,
156                scale_dtype: scale.dtype,
157                shape,
158                scheme,
159            },
160        }
161    }
162
163    pub fn from_ref(handle: &MatmulInputHandleRef<'_, R>) -> Self {
164        match handle {
165            MatmulInputHandleRef::Normal(handle, dtype) => {
166                MatmulInputHandle::Normal(TensorHandle::from_ref(handle, *dtype))
167            }
168            MatmulInputHandleRef::Quantized {
169                data,
170                scale,
171                shape,
172                scheme,
173                data_dtype,
174                scale_dtype,
175            } => MatmulInputHandle::Quantized {
176                data: TensorHandle::from_ref(data, *data_dtype),
177                scale: TensorHandle::from_ref(scale, *scale_dtype),
178                shape: shape.to_vec(),
179                scheme: **scheme,
180            },
181        }
182    }
183
184    pub fn data(&self) -> &TensorHandle<R> {
185        match self {
186            MatmulInputHandle::Normal(handle) => handle,
187            MatmulInputHandle::Quantized { data, .. } => data,
188        }
189    }
190
191    pub fn swap_dims(&mut self, dim0: usize, dim1: usize) {
192        match self {
193            MatmulInputHandle::Normal(handle) => {
194                handle.shape.swap(dim0, dim1);
195                handle.strides.swap(dim0, dim1);
196            }
197            MatmulInputHandle::Quantized {
198                data, scale, shape, ..
199            } => {
200                data.shape.swap(dim0, dim1);
201                data.strides.swap(dim0, dim1);
202                if scale.shape.len() == data.shape.len() {
203                    scale.shape.swap(dim0, dim1);
204                    scale.strides.swap(dim0, dim1);
205                }
206                shape.swap(dim0, dim1);
207            }
208        }
209    }
210}
211
212impl<R: Runtime> Clone for MatmulInputHandle<R> {
213    fn clone(&self) -> Self {
214        match self {
215            Self::Normal(handle) => Self::Normal(handle.clone()),
216            Self::Quantized {
217                data,
218                scale,
219                shape,
220                scheme,
221            } => Self::Quantized {
222                data: data.clone(),
223                scale: scale.clone(),
224                shape: shape.clone(),
225                scheme: *scheme,
226            },
227        }
228    }
229}
230
/// Borrowed counterpart of [`MatmulInputHandle`], carrying explicit storage types.
#[derive(Debug)]
pub enum MatmulInputHandleRef<'a, R: Runtime> {
    /// Regular (non-quantized) tensor ref with its storage type.
    Normal(TensorHandleRef<'a, R>, StorageType),
    /// Quantized tensor ref with scales, shape and scheme.
    Quantized {
        // Quantized (possibly packed) values.
        data: TensorHandleRef<'a, R>,
        // Storage type of `data`.
        data_dtype: StorageType,
        // Quantization scales.
        scale: TensorHandleRef<'a, R>,
        // Storage type of `scale`.
        scale_dtype: StorageType,
        /// Unpacked shape, excluding padding
        shape: &'a [usize],
        scheme: &'a QuantScheme,
    },
}
244
// Manual impl: a derive would require `R: Clone`. The ref type is plain `Copy`
// data (references, slices, `StorageType`), so cloning is a bitwise copy.
impl<'a, R: Runtime> Clone for MatmulInputHandleRef<'a, R> {
    fn clone(&self) -> Self {
        *self
    }
}
250
// Manual impl to avoid the `R: Copy` bound a derive would introduce.
impl<'a, R: Runtime> Copy for MatmulInputHandleRef<'a, R> {}
252
impl<'a, R: Runtime> MatmulInputHandleRef<'a, R> {
    /// Wraps a plain (non-quantized) tensor ref with its storage type.
    pub fn new(data: TensorHandleRef<'a, R>, dtype: StorageType) -> Self {
        Self::Normal(data, dtype)
    }

    /// Wraps a quantized tensor ref together with its scales, unpacked shape
    /// and quantization scheme.
    pub fn quantized(
        data: TensorHandleRef<'a, R>,
        scale: TensorHandleRef<'a, R>,
        shape: &'a [usize],
        scheme: &'a QuantScheme,
        data_dtype: StorageType,
        scale_dtype: StorageType,
    ) -> Self {
        Self::Quantized {
            data,
            scale,
            shape,
            scheme,
            data_dtype,
            scale_dtype,
        }
    }

    /// The data tensor (the quantized values for quantized inputs).
    pub fn data(&self) -> &TensorHandleRef<'a, R> {
        match self {
            MatmulInputHandleRef::Normal(handle, ..) => handle,
            MatmulInputHandleRef::Quantized { data, .. } => data,
        }
    }

    /// Mutable access to the data tensor ref.
    pub fn data_mut(&mut self) -> &mut TensorHandleRef<'a, R> {
        match self {
            MatmulInputHandleRef::Normal(handle, ..) => handle,
            MatmulInputHandleRef::Quantized { data, .. } => data,
        }
    }

    /// The quantization scales, or `None` for non-quantized inputs.
    pub fn scale(&self) -> Option<&TensorHandleRef<'a, R>> {
        match self {
            MatmulInputHandleRef::Normal(..) => None,
            MatmulInputHandleRef::Quantized { scale, .. } => Some(scale),
        }
    }

    /// The quantization scheme, or `None` for non-quantized inputs.
    pub fn scheme(&self) -> Option<&QuantScheme> {
        match self {
            MatmulInputHandleRef::Normal(..) => None,
            MatmulInputHandleRef::Quantized { scheme, .. } => Some(scheme),
        }
    }

    /// Logical shape: the tensor's shape for normal inputs, the unpacked shape
    /// for quantized inputs.
    pub fn shape(&self) -> &[usize] {
        match self {
            MatmulInputHandleRef::Normal(handle, ..) => handle.shape,
            MatmulInputHandleRef::Quantized { shape, .. } => shape,
        }
    }

    /// Copies the tensor(s) into a contiguous layout on `client`, returning an
    /// owned handle.
    ///
    /// Packed quantized data (`QuantStore::U32`, or native `E2M1` which packs
    /// two values per byte) goes through `into_contiguous_packed` and is then
    /// reinterpreted as the original quantized storage type; everything else
    /// goes through `into_contiguous_pitched`.
    pub fn into_contiguous(&self, client: &ComputeClient<R::Server>) -> MatmulInputHandle<R> {
        match self {
            MatmulInputHandleRef::Normal(data, dtype) => {
                MatmulInputHandle::Normal(into_contiguous_pitched::<R>(client, data, *dtype))
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
                data_dtype,
                scale_dtype,
            } => {
                let data = match scheme.store {
                    // e2m1 has native packing (e2m1x2) so also needs to be re-packed
                    QuantStore::Native if scheme.value == QuantValue::E2M1 => {
                        // Pack two quantized values per u8 element.
                        let data = into_contiguous_packed::<R>(
                            client,
                            data,
                            shape,
                            2,
                            u8::as_type_native_unchecked(),
                        );
                        // Unsafely cast the packed u8 buffer back to the
                        // original quantized storage type `data_dtype`.
                        TensorHandle::from_ref(&data.as_ref(), *data_dtype)
                    }
                    QuantStore::U32 => {
                        // Pack `num_quants` values per u32 element.
                        let data = into_contiguous_packed::<R>(
                            client,
                            data,
                            shape,
                            scheme.num_quants() as u32,
                            u32::as_type_native_unchecked(),
                        );
                        // Unsafely cast the packed u32 buffer back to the
                        // original quantized storage type `data_dtype`.
                        TensorHandle::from_ref(&data.as_ref(), *data_dtype)
                    }
                    _ => into_contiguous_pitched::<R>(client, data, *data_dtype),
                };
                MatmulInputHandle::Quantized {
                    data,
                    scale: TensorHandle::from_ref(scale, *scale_dtype),
                    shape: shape.to_vec(),
                    scheme: **scheme,
                }
            }
        }
    }
}
360
361#[allow(clippy::result_large_err)]
362pub fn launch<R: Runtime>(
363    strategy: &Strategy,
364    client: &ComputeClient<R::Server>,
365    lhs: MatmulInputHandle<R>,
366    rhs: MatmulInputHandle<R>,
367    out: TensorHandle<R>,
368    mut dtypes: MatmulElems,
369) -> Result<(), MatmulSetupError> {
370    launch_ref::<R>(
371        strategy,
372        client,
373        &lhs.as_ref(),
374        &rhs.as_ref(),
375        &out.as_ref(),
376        &mut dtypes,
377    )
378}
379
380#[allow(clippy::result_large_err)]
381/// Launches a matrix multiplication kernel..
382///
383/// # Notes
384///
385/// The matmul elements may get changed during selection for improved performance when
386/// the hardware supports it.
387/// Only the inner element types may change such as the stage or register element types.
388pub fn launch_ref<R: Runtime>(
389    strategy: &Strategy,
390    client: &ComputeClient<R::Server>,
391    lhs: &MatmulInputHandleRef<R>,
392    rhs: &MatmulInputHandleRef<R>,
393    out: &TensorHandleRef<R>,
394    dtypes: &mut MatmulElems,
395) -> Result<(), MatmulSetupError> {
396    match strategy {
397        Strategy::Simple {
398            read_strategy,
399            selection,
400            tile_kind,
401        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
402            ReadingStrategy::Cyclic => {
403                layered::launch_ref::<R, SimpleAlgorithm<Accelerated>>(
404                    client, lhs, rhs, out, selection, dtypes,
405                )
406            }
407            ReadingStrategy::Strided => layered::launch_ref::<
408                R,
409                SimpleAlgorithm<
410                    Accelerated,
411                    sync_full_strided::SyncFullStridedLoading,
412                    sync_full_strided::SyncFullStridedLoading,
413                >,
414            >(client, lhs, rhs, out, selection, dtypes),
415            ReadingStrategy::Tilewise => {
416                layered::launch_ref::<
417                    R,
418                    SimpleAlgorithm<
419                        Accelerated,
420                        sync_full_tilewise::SyncFullTilewiseLoading<ColMajorTilingOrder>,
421                        sync_full_tilewise::SyncFullTilewiseLoading<RowMajorTilingOrder>,
422                    >,
423                >(client, lhs, rhs, out, selection, dtypes)
424            }
425            ReadingStrategy::AsyncCooperative => {
426                layered::launch_ref::<
427                    R,
428                    SimpleAlgorithm<
429                        Accelerated,
430                        async_full_cooperative::AsyncFullCooperativeLoading,
431                        async_full_cooperative::AsyncFullCooperativeLoading,
432                    >,
433                >(client, lhs, rhs, out, selection, dtypes)
434            }
435            ReadingStrategy::AsyncCyclic => {
436                layered::launch_ref::<
437                    R,
438                    SimpleAlgorithm<
439                        Accelerated,
440                        async_full_cyclic::AsyncFullCyclicLoading<ColMajorTilingOrder>,
441                        async_full_cyclic::AsyncFullCyclicLoading<RowMajorTilingOrder>,
442                    >,
443                >(client, lhs, rhs, out, selection, dtypes)
444            }
445            ReadingStrategy::AsyncMaximizeSliceLength => {
446                layered::launch_ref::<
447                    R,
448                    SimpleAlgorithm<
449                        Accelerated,
450                        async_full_maximize_slice_length::AsyncFullMaximizeSliceLengthLoading,
451                        async_full_maximize_slice_length::AsyncFullMaximizeSliceLengthLoading,
452                    >,
453                >(client, lhs, rhs, out, &Default::default(), dtypes)
454            }
455            ReadingStrategy::AsyncMaximizeUnitCount => {
456                layered::launch_ref::<
457                    R,
458                    SimpleAlgorithm<
459                        Accelerated,
460                        async_full_maximize_unit_count::AsyncFullMaximizeUnitCountLoading,
461                        async_full_maximize_unit_count::AsyncFullMaximizeUnitCountLoading,
462                    >,
463                >(client, lhs, rhs, out, &Default::default(), dtypes)
464            }
465            ReadingStrategy::Tma => layered::launch_ref_tma::<R, SimpleTmaAlgorithm<Accelerated>>(
466                client, lhs, rhs, out, selection, dtypes
467            ),
468        }),
469        Strategy::DoubleBuffering {
470            read_strategy,
471            selection,
472            tile_kind,
473        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
474            PartialReadingStrategy::Cyclic => {
475                layered::launch_ref::<R, CyclicDoubleBufferingAlgorithm<Accelerated>>(
476                    client, lhs, rhs, out, selection, dtypes,
477                )
478            }
479            PartialReadingStrategy::Tilewise => {
480                layered::launch_ref::<R, TilewiseDoubleBufferingAlgorithm<Accelerated>>(
481                    client, lhs, rhs, out, selection, dtypes,
482                )
483            }
484            PartialReadingStrategy::Hybrid => {
485                layered::launch_ref::<R, HybridDoubleBufferingAlgorithm<Accelerated>>(
486                    client, lhs, rhs, out, selection, dtypes,
487                )
488            }
489            PartialReadingStrategy::Tma => {
490                layered::launch_ref_tma::<R, TmaDoubleBufferingAlgorithm<Accelerated>>(
491                    client, lhs, rhs, out, selection, dtypes,
492                )
493            }
494        }),
495        Strategy::Specialized {
496            selection,
497            tile_kind,
498        } => with_tile_kind!(tile_kind, Accelerated, || layered::launch_ref_tma::<
499            R,
500            TmaSpecializedAlgorithm<Accelerated>,
501        >(
502            client, lhs, rhs, out, selection, dtypes
503        )),
504        Strategy::OrderedDoubleBuffering {
505            selection,
506            tile_kind,
507        } => with_tile_kind!(tile_kind, Accelerated, || layered::launch_ref::<
508            R,
509            OrderedDoubleBufferingAlgorithm<Accelerated>,
510        >(
511            client, lhs, rhs, out, selection, dtypes
512        )),
513        Strategy::SimpleUnit(selection) => {
514            layered::launch_ref::<R, SimpleUnitAlgorithm>(client, lhs, rhs, out, selection, dtypes)
515        }
516        Strategy::DoubleUnit(selection) => {
517            layered::launch_ref::<R, DoubleUnitAlgorithm>(client, lhs, rhs, out, selection, dtypes)
518        }
519        Strategy::Naive => {
520            naive::launch_ref::<R>(client, lhs, rhs, out, dtypes)?;
521            Ok(())
522        }
523        Strategy::Auto => {
524            if let Err(err) = layered::launch_ref::<R, SimpleAlgorithm<CmmaMatmul<Filled>>>(
525                client,
526                lhs,
527                rhs,
528                out,
529                &Default::default(),
530                dtypes,
531            ) {
532                match err {
533                    MatmulSetupError::Unavailable(_) => {
534                        layered::launch_ref::<R, SimpleUnitAlgorithm>(
535                            client,
536                            lhs,
537                            rhs,
538                            out,
539                            &Default::default(),
540                            dtypes,
541                        )
542                        .unwrap();
543                    }
544                    _ => panic!("{err:?}"),
545                }
546            }
547
548            Ok(())
549        }
550        Strategy::SimpleVecMat(selection) => layered::launch_ref::<R, SimpleVecMatAlgorithm>(
551            client, lhs, rhs, out, selection, dtypes,
552        ),
553        Strategy::DoubleVecMat(selection) => layered::launch_ref::<R, DoubleVecMatAlgorithm>(
554            client, lhs, rhs, out, selection, dtypes,
555        ),
556    }
557}