// cubecl_matmul/base.rs — matmul strategy selection and launch entry points.

1use std::fmt::Display;
2
3use cubecl_common::quant::scheme::{QuantScheme, QuantStore, QuantValue};
4use cubecl_core::{
5    Runtime,
6    client::ComputeClient,
7    ir::StorageType,
8    prelude::{CubePrimitive, TensorHandleRef},
9    server::LaunchError,
10};
11
12use cubecl_std::tensor::{TensorHandle, into_contiguous_packed, into_contiguous_pitched};
13use serde::{Deserialize, Serialize};
14
15use crate::{
16    components::{
17        MatmulElems, MatmulSetupError,
18        global::read::{
19            async_partial_cyclic::AsyncPartialCyclicLoading,
20            async_partial_strided::AsyncPartialStridedLoading,
21        },
22        tile::{cmma::CmmaMatmul, io::Filled, mma::MmaMatmul},
23    },
24    kernels::layered::{
25        Selection,
26        double_buffering::*,
27        double_unit::{DoubleUnitAlgorithm, DoubleUnitSelectionArgs},
28        ordered_double_buffering::OrderedSelectionArgs,
29        simple::SimpleArgs,
30        simple_unit::SimpleUnitSelectionArgs,
31        specialized::SpecializedAlgorithm,
32        vecmat::{DoubleVecMatAlgorithm, SimpleVecMatAlgorithm},
33    },
34};
35
36use super::{
37    components::{
38        global::read::{
39            async_full_cooperative, async_full_cyclic, sync_full_strided, sync_full_tilewise,
40        },
41        stage::{ColMajorTilingOrder, RowMajorTilingOrder},
42    },
43    kernels::{
44        layered::{
45            self,
46            double_buffering::{
47                CyclicDoubleBufferingAlgorithm, HybridDoubleBufferingAlgorithm,
48                TilewiseDoubleBufferingAlgorithm,
49            },
50            ordered_double_buffering::OrderedDoubleBufferingAlgorithm,
51            simple::{SimpleAlgorithm, SimpleTmaAlgorithm},
52            simple_unit::SimpleUnitAlgorithm,
53        },
54        naive,
55    },
56};
57
#[derive(Debug, Clone, Default)]
/// The matmul algorithm to launch
///
/// Most strategies have a selection input that can be overwritten or inferred from minimal information
/// Some strategies must have a specified loading strategy
pub enum Strategy {
    /// Single-buffered accelerated matmul with a configurable reader and tile kind.
    Simple {
        read_strategy: ReadingStrategy,
        selection: Selection<SimpleArgs>,
        tile_kind: AcceleratedTileKind,
    },
    /// Double-buffered accelerated matmul with a configurable partial reader.
    DoubleBuffering {
        read_strategy: PartialReadingStrategy,
        selection: Selection<DoubleBufferingArgs>,
        tile_kind: AcceleratedTileKind,
    },
    /// Specialized accelerated matmul using asynchronous partial readers.
    Specialized {
        read_strategy: AsyncPartialReadingStrategy,
        selection: Selection<()>,
        tile_kind: AcceleratedTileKind,
    },
    /// Simple matmul running on plain units (no accelerated tile kind).
    SimpleUnit(Selection<SimpleUnitSelectionArgs>),
    /// Double-buffered matmul running on plain units.
    DoubleUnit(Selection<DoubleUnitSelectionArgs>),
    /// Simple vector-matrix product.
    SimpleVecMat(Selection<()>),
    /// Double-buffered vector-matrix product.
    DoubleVecMat(Selection<()>),
    /// Ordered variant of double buffering (reader fixed by the algorithm).
    OrderedDoubleBuffering {
        selection: Selection<OrderedSelectionArgs>,
        tile_kind: AcceleratedTileKind,
    },
    /// Non-layered fallback implementation.
    Naive,
    #[default]
    /// Tries using a Simple matmul, then a SimpleUnit if the former failed
    Auto,
}
92
93impl Display for Strategy {
94    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
95        match self {
96            Strategy::Simple {
97                read_strategy,
98                selection,
99                tile_kind,
100            } => {
101                f.write_fmt(format_args!("matmul_simple_{read_strategy}_{tile_kind}"))?;
102
103                match selection {
104                    Selection::Forced(_) => f.write_str("_forced_selection")?,
105                    Selection::Inferred(args) => {
106                        if args.multi_rows {
107                            f.write_str("_multirows")?;
108                        }
109                    }
110                };
111            }
112            Strategy::DoubleBuffering {
113                read_strategy,
114                selection,
115                tile_kind,
116            } => {
117                f.write_fmt(format_args!(
118                    "matmul_double_buffering_{read_strategy}_{tile_kind}"
119                ))?;
120
121                match selection {
122                    Selection::Forced(_) => f.write_str("_forced_selection")?,
123                    Selection::Inferred(args) => {
124                        if args.specialized {
125                            f.write_str("_specialized")?;
126                        }
127                    }
128                };
129            }
130            Strategy::Specialized {
131                read_strategy,
132                selection,
133                tile_kind,
134            } => {
135                f.write_fmt(format_args!(
136                    "matmul_specialized_{read_strategy}_{tile_kind}"
137                ))?;
138
139                match selection {
140                    Selection::Forced(_) => f.write_str("_forced_selection")?,
141                    Selection::Inferred(_) => {}
142                };
143            }
144            Strategy::SimpleUnit(selection) => {
145                f.write_fmt(format_args!("matmul_simple_unit"))?;
146
147                match selection {
148                    Selection::Forced(_) => f.write_str("_forced_selection")?,
149                    Selection::Inferred(args) => {
150                        f.write_fmt(format_args!("_{}", args.tile_size))?;
151                    }
152                };
153            }
154            Strategy::DoubleUnit(selection) => {
155                f.write_str("matmul_double_buffering_unit")?;
156
157                match selection {
158                    Selection::Forced(_) => f.write_str("_forced_selection")?,
159                    Selection::Inferred(args) => {
160                        f.write_fmt(format_args!("_{}", args.tile_size))?;
161                    }
162                };
163            }
164            Strategy::SimpleVecMat(selection) => {
165                f.write_str("vecmat_simple")?;
166
167                match selection {
168                    Selection::Forced(_) => f.write_str("_forced_selection")?,
169                    Selection::Inferred(_) => {}
170                };
171            }
172            Strategy::DoubleVecMat(selection) => {
173                f.write_str("vecmat_double_buffering")?;
174
175                match selection {
176                    Selection::Forced(_) => f.write_str("_forced_selection")?,
177                    Selection::Inferred(_) => {}
178                };
179            }
180            Strategy::OrderedDoubleBuffering {
181                selection,
182                tile_kind,
183            } => {
184                f.write_fmt(format_args!("matmul_double_buffering_ordered_{tile_kind}"))?;
185
186                match selection {
187                    Selection::Forced(_) => f.write_str("_forced_selection")?,
188                    Selection::Inferred(args) => {
189                        if let Some(k) = args.partition_k {
190                            f.write_fmt(format_args!("_partition_k{}", k))?;
191                        }
192                        if let Some(r) = args.row_count {
193                            f.write_fmt(format_args!("_row_count{}", r))?;
194                        }
195                        if let Some(r) = args.rows_per_plane {
196                            f.write_fmt(format_args!("_row_per_plane{}", r))?;
197                        }
198                    }
199                };
200            }
201            Strategy::Naive => f.write_str("matmul_naive")?,
202            Strategy::Auto => f.write_str("matmul_auto")?,
203        };
204
205        Ok(())
206    }
207}
208
#[derive(Debug, Clone, Copy)]
/// Which reader to use in simple algorithms
pub enum ReadingStrategy {
    /// Synchronous cyclic loading (the `SimpleAlgorithm` default loaders).
    Cyclic,
    /// Synchronous strided loading (`sync_full_strided::SyncFullStridedLoading`).
    Strided,
    /// Synchronous tilewise loading (`sync_full_tilewise::SyncFullTilewiseLoading`).
    Tilewise,
    /// Asynchronous cooperative loading (`async_full_cooperative`).
    AsyncCooperative,
    /// Asynchronous cyclic loading (`async_full_cyclic`).
    AsyncCyclic,
    /// TMA-based loading (`SimpleTmaAlgorithm`, launched via `launch_ref_tma`).
    Tma,
}
219
#[derive(Debug, Clone, Copy)]
/// Which reader to use in double buffering algorithms
pub enum PartialReadingStrategy {
    /// `CyclicDoubleBufferingAlgorithm`.
    Cyclic,
    /// `TilewiseDoubleBufferingAlgorithm`.
    Tilewise,
    /// `HybridDoubleBufferingAlgorithm`.
    Hybrid,
    /// `TmaDoubleBufferingAlgorithm` (launched via `launch_ref_tma`).
    Tma,
    /// `AsyncCyclicDoubleBufferingAlgorithm`.
    AsyncCyclic,
    /// `AsyncStridedDoubleBufferingAlgorithm`.
    AsyncStrided,
}
230
#[derive(Debug, Clone, Copy)]
/// Which reader to use in specialized algorithms
pub enum AsyncPartialReadingStrategy {
    /// `AsyncPartialCyclicLoading` with a column-major tiling order.
    Cyclic,
    /// `AsyncPartialStridedLoading`.
    Strided,
    /// TMA-based specialized launch (via `launch_ref_tma`).
    Tma,
}
238
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
/// Which tile matmul to use for accelerated algorithms
pub enum AcceleratedTileKind {
    #[default]
    /// Cooperative MMA tile matmul (`CmmaMatmul<Filled>`).
    Cmma,
    /// Direct MMA tile matmul (`MmaMatmul`).
    Mma,
}
246
247// Display implementations are used to combine and save names when autotuning.
248
249impl Display for AcceleratedTileKind {
250    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
251        match self {
252            AcceleratedTileKind::Cmma => f.write_str("cmma"),
253            AcceleratedTileKind::Mma => f.write_str("mma"),
254        }
255    }
256}
257
258impl Display for ReadingStrategy {
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        match self {
261            ReadingStrategy::Cyclic => f.write_str("cyclic"),
262            ReadingStrategy::Strided => f.write_str("strided"),
263            ReadingStrategy::Tilewise => f.write_str("tilewise"),
264            ReadingStrategy::AsyncCooperative => f.write_str("async_cooperative"),
265            ReadingStrategy::AsyncCyclic => f.write_str("async_cyclic"),
266            ReadingStrategy::Tma => f.write_str("tma"),
267        }
268    }
269}
270
271impl Display for PartialReadingStrategy {
272    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273        match self {
274            PartialReadingStrategy::Cyclic => f.write_str("cyclic"),
275            PartialReadingStrategy::Tilewise => f.write_str("tilewise"),
276            PartialReadingStrategy::Hybrid => f.write_str("hybrid"),
277            PartialReadingStrategy::Tma => f.write_str("tma"),
278            PartialReadingStrategy::AsyncCyclic => f.write_str("async_cyclic"),
279            PartialReadingStrategy::AsyncStrided => f.write_str("async_strided"),
280        }
281    }
282}
283
284impl Display for AsyncPartialReadingStrategy {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        match self {
287            AsyncPartialReadingStrategy::Cyclic => f.write_str("cyclic"),
288            AsyncPartialReadingStrategy::Strided => f.write_str("strided"),
289            AsyncPartialReadingStrategy::Tma => f.write_str("tma"),
290        }
291    }
292}
293
/// Dispatches on an [`AcceleratedTileKind`] value: binds the type alias `$T`
/// to the matching concrete tile matmul type, then invokes the `$launch`
/// closure, returning its result.
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            // CMMA tiles, parameterized by the `Filled` io type.
            AcceleratedTileKind::Cmma => {
                type $T = CmmaMatmul<Filled>;
                ($launch)()
            }
            // Direct MMA tiles.
            AcceleratedTileKind::Mma => {
                type $T = MmaMatmul;
                ($launch)()
            }
        }
    };
}
308
/// Owned matmul input: either a plain tensor or quantized data with its scales.
pub enum MatmulInputHandle<R: Runtime> {
    /// Plain (non-quantized) input tensor.
    Normal(TensorHandle<R>),
    /// Quantized input tensor plus the metadata needed to interpret it.
    Quantized {
        // Possibly packed quantized values.
        data: TensorHandle<R>,
        // Quantization scales.
        scale: TensorHandle<R>,
        // Unpacked logical shape, excluding packing padding.
        shape: Vec<usize>,
        scheme: QuantScheme,
    },
}
318
319impl<R: Runtime> MatmulInputHandle<R> {
320    pub fn as_ref(&self) -> MatmulInputHandleRef<'_, R> {
321        match self {
322            MatmulInputHandle::Normal(handle) => {
323                MatmulInputHandleRef::Normal(handle.as_ref(), handle.dtype)
324            }
325            MatmulInputHandle::Quantized {
326                data,
327                scale,
328                shape,
329                scheme,
330            } => MatmulInputHandleRef::Quantized {
331                data: data.as_ref(),
332                scale: scale.as_ref(),
333                data_dtype: data.dtype,
334                scale_dtype: scale.dtype,
335                shape,
336                scheme,
337            },
338        }
339    }
340
341    pub fn from_ref(handle: &MatmulInputHandleRef<'_, R>) -> Self {
342        match handle {
343            MatmulInputHandleRef::Normal(handle, dtype) => {
344                MatmulInputHandle::Normal(TensorHandle::from_ref(handle, *dtype))
345            }
346            MatmulInputHandleRef::Quantized {
347                data,
348                scale,
349                shape,
350                scheme,
351                data_dtype,
352                scale_dtype,
353            } => MatmulInputHandle::Quantized {
354                data: TensorHandle::from_ref(data, *data_dtype),
355                scale: TensorHandle::from_ref(scale, *scale_dtype),
356                shape: shape.to_vec(),
357                scheme: **scheme,
358            },
359        }
360    }
361
362    pub fn data(&self) -> &TensorHandle<R> {
363        match self {
364            MatmulInputHandle::Normal(handle) => handle,
365            MatmulInputHandle::Quantized { data, .. } => data,
366        }
367    }
368
369    pub fn swap_dims(&mut self, dim0: usize, dim1: usize) {
370        match self {
371            MatmulInputHandle::Normal(handle) => {
372                handle.shape.swap(dim0, dim1);
373                handle.strides.swap(dim0, dim1);
374            }
375            MatmulInputHandle::Quantized {
376                data, scale, shape, ..
377            } => {
378                data.shape.swap(dim0, dim1);
379                data.strides.swap(dim0, dim1);
380                if scale.shape.len() == data.shape.len() {
381                    scale.shape.swap(dim0, dim1);
382                    scale.strides.swap(dim0, dim1);
383                }
384                shape.swap(dim0, dim1);
385            }
386        }
387    }
388}
389
// Manual impl: `#[derive(Clone)]` would add an `R: Clone` bound on the
// `Runtime` parameter, which the contained handles do not need.
impl<R: Runtime> Clone for MatmulInputHandle<R> {
    fn clone(&self) -> Self {
        match self {
            Self::Normal(handle) => Self::Normal(handle.clone()),
            Self::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => Self::Quantized {
                data: data.clone(),
                scale: scale.clone(),
                shape: shape.clone(),
                scheme: *scheme,
            },
        }
    }
}
408
/// Borrowed matmul input: either a plain tensor ref or quantized data with its
/// scales. Mirrors [`MatmulInputHandle`] without taking ownership.
#[derive(Debug)]
pub enum MatmulInputHandleRef<'a, R: Runtime> {
    /// Plain (non-quantized) input tensor and its storage type.
    Normal(TensorHandleRef<'a, R>, StorageType),
    /// Quantized input tensor plus the metadata needed to interpret it.
    Quantized {
        // Possibly packed quantized values.
        data: TensorHandleRef<'a, R>,
        data_dtype: StorageType,
        // Quantization scales.
        scale: TensorHandleRef<'a, R>,
        scale_dtype: StorageType,
        /// Unpacked shape, excluding padding
        shape: &'a [usize],
        scheme: &'a QuantScheme,
    },
}
422
// Manual `Clone`/`Copy` impls: `#[derive]` would add `R: Clone`/`R: Copy`
// bounds, while the borrowed contents are bitwise-copyable regardless of `R`.
impl<'a, R: Runtime> Clone for MatmulInputHandleRef<'a, R> {
    fn clone(&self) -> Self {
        // Delegates to the `Copy` impl below.
        *self
    }
}

impl<'a, R: Runtime> Copy for MatmulInputHandleRef<'a, R> {}
430
impl<'a, R: Runtime> MatmulInputHandleRef<'a, R> {
    /// Wraps a plain (non-quantized) tensor ref with its storage type.
    pub fn new(data: TensorHandleRef<'a, R>, dtype: StorageType) -> Self {
        Self::Normal(data, dtype)
    }

    /// Wraps a quantized tensor ref with its scales and quantization scheme.
    ///
    /// `shape` is the unpacked logical shape, excluding packing padding.
    pub fn quantized(
        data: TensorHandleRef<'a, R>,
        scale: TensorHandleRef<'a, R>,
        shape: &'a [usize],
        scheme: &'a QuantScheme,
        data_dtype: StorageType,
        scale_dtype: StorageType,
    ) -> Self {
        Self::Quantized {
            data,
            scale,
            shape,
            scheme,
            data_dtype,
            scale_dtype,
        }
    }

    /// Returns the data tensor (the quantized values for quantized inputs).
    pub fn data(&self) -> &TensorHandleRef<'a, R> {
        match self {
            MatmulInputHandleRef::Normal(handle, ..) => handle,
            MatmulInputHandleRef::Quantized { data, .. } => data,
        }
    }

    /// Mutable access to the data tensor.
    pub fn data_mut(&mut self) -> &mut TensorHandleRef<'a, R> {
        match self {
            MatmulInputHandleRef::Normal(handle, ..) => handle,
            MatmulInputHandleRef::Quantized { data, .. } => data,
        }
    }

    /// Returns the quantization scales, or `None` for non-quantized inputs.
    pub fn scale(&self) -> Option<&TensorHandleRef<'a, R>> {
        match self {
            MatmulInputHandleRef::Normal(..) => None,
            MatmulInputHandleRef::Quantized { scale, .. } => Some(scale),
        }
    }

    /// Returns the quantization scheme, or `None` for non-quantized inputs.
    pub fn scheme(&self) -> Option<&QuantScheme> {
        match self {
            MatmulInputHandleRef::Normal(..) => None,
            MatmulInputHandleRef::Quantized { scheme, .. } => Some(scheme),
        }
    }

    /// Returns the logical (unpacked) shape of the input.
    pub fn shape(&self) -> &[usize] {
        match self {
            MatmulInputHandleRef::Normal(handle, ..) => handle.shape,
            MatmulInputHandleRef::Quantized { shape, .. } => shape,
        }
    }

    /// Copies the input into a contiguous layout, returning an owned handle.
    ///
    /// Packed quantized data (`QuantStore::U32`, or native `E2M1` which packs
    /// two values per byte) goes through `into_contiguous_packed`; everything
    /// else goes through `into_contiguous_pitched`.
    ///
    /// # Errors
    ///
    /// Propagates any [`LaunchError`] from the copy kernels.
    pub fn into_contiguous(
        &self,
        client: &ComputeClient<R>,
    ) -> Result<MatmulInputHandle<R>, LaunchError> {
        let val = match self {
            MatmulInputHandleRef::Normal(data, dtype) => {
                MatmulInputHandle::Normal(into_contiguous_pitched(client, data, *dtype)?)
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
                data_dtype,
                scale_dtype,
            } => {
                let data = match scheme.store {
                    // e2m1 has native packing (e2m1x2) so also needs to be re-packed
                    QuantStore::Native if scheme.value == QuantValue::E2M1 => {
                        let data = into_contiguous_packed(
                            client,
                            data,
                            shape,
                            2,
                            u8::as_type_native_unchecked(),
                        )?;
                        // Unsafely reinterpret the packed u8 buffer as the original storage type
                        TensorHandle::from_ref(&data.as_ref(), *data_dtype)
                    }
                    QuantStore::U32 => {
                        let data = into_contiguous_packed(
                            client,
                            data,
                            shape,
                            scheme.num_quants() as u32,
                            u32::as_type_native_unchecked(),
                        )?;
                        // Unsafely reinterpret the packed u32 buffer as the original storage type
                        TensorHandle::from_ref(&data.as_ref(), *data_dtype)
                    }
                    _ => into_contiguous_pitched(client, data, *data_dtype)?,
                };
                MatmulInputHandle::Quantized {
                    data,
                    scale: TensorHandle::from_ref(scale, *scale_dtype),
                    shape: shape.to_vec(),
                    scheme: **scheme,
                }
            }
        };

        Ok(val)
    }
}
543
544#[allow(clippy::result_large_err)]
545pub fn launch<R: Runtime>(
546    strategy: &Strategy,
547    client: &ComputeClient<R>,
548    lhs: MatmulInputHandle<R>,
549    rhs: MatmulInputHandle<R>,
550    out: TensorHandle<R>,
551    mut dtypes: MatmulElems,
552) -> Result<(), MatmulSetupError> {
553    launch_ref(
554        strategy,
555        client,
556        &lhs.as_ref(),
557        &rhs.as_ref(),
558        &out.as_ref(),
559        &mut dtypes,
560    )
561}
562
563#[allow(clippy::result_large_err)]
564/// Launches a matrix multiplication kernel..
565///
566/// # Notes
567///
568/// The matmul elements may get changed during selection for improved performance when
569/// the hardware supports it.
570/// Only the inner element types may change such as the stage or register element types.
571pub fn launch_ref<R: Runtime>(
572    strategy: &Strategy,
573    client: &ComputeClient<R>,
574    lhs: &MatmulInputHandleRef<R>,
575    rhs: &MatmulInputHandleRef<R>,
576    out: &TensorHandleRef<R>,
577    dtypes: &mut MatmulElems,
578) -> Result<(), MatmulSetupError> {
579    match strategy {
580        Strategy::Simple {
581            read_strategy,
582            selection,
583            tile_kind,
584        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
585            ReadingStrategy::Cyclic => {
586                layered::launch_ref::<R, SimpleAlgorithm<Accelerated>>(
587                    client, lhs, rhs, out, selection, dtypes,
588                )
589            }
590            ReadingStrategy::Strided => layered::launch_ref::<
591                R,
592                SimpleAlgorithm<
593                    Accelerated,
594                    sync_full_strided::SyncFullStridedLoading,
595                    sync_full_strided::SyncFullStridedLoading,
596                >,
597            >(client, lhs, rhs, out, selection, dtypes),
598            ReadingStrategy::Tilewise => {
599                layered::launch_ref::<
600                    R,
601                    SimpleAlgorithm<
602                        Accelerated,
603                        sync_full_tilewise::SyncFullTilewiseLoading<ColMajorTilingOrder>,
604                        sync_full_tilewise::SyncFullTilewiseLoading<RowMajorTilingOrder>,
605                    >,
606                >(client, lhs, rhs, out, selection, dtypes)
607            }
608            ReadingStrategy::AsyncCooperative => {
609                layered::launch_ref::<
610                    R,
611                    SimpleAlgorithm<
612                        Accelerated,
613                        async_full_cooperative::AsyncFullCooperativeLoading,
614                        async_full_cooperative::AsyncFullCooperativeLoading,
615                    >,
616                >(client, lhs, rhs, out, selection, dtypes)
617            }
618            ReadingStrategy::AsyncCyclic => {
619                layered::launch_ref::<
620                    R,
621                    SimpleAlgorithm<
622                        Accelerated,
623                        async_full_cyclic::AsyncFullCyclicLoading<ColMajorTilingOrder>,
624                        async_full_cyclic::AsyncFullCyclicLoading<RowMajorTilingOrder>,
625                    >,
626                >(client, lhs, rhs, out, selection, dtypes)
627            }
628            ReadingStrategy::Tma => layered::launch_ref_tma::<R, SimpleTmaAlgorithm<Accelerated>>(
629                client, lhs, rhs, out, selection, dtypes
630            ),
631        }),
632        Strategy::DoubleBuffering {
633            read_strategy,
634            selection,
635            tile_kind,
636        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
637            PartialReadingStrategy::Cyclic => {
638                layered::launch_ref::<R, CyclicDoubleBufferingAlgorithm<Accelerated>>(
639                    client, lhs, rhs, out, selection, dtypes,
640                )
641            }
642            PartialReadingStrategy::Tilewise => {
643                layered::launch_ref::<R, TilewiseDoubleBufferingAlgorithm<Accelerated>>(
644                    client, lhs, rhs, out, selection, dtypes,
645                )
646            }
647            PartialReadingStrategy::Hybrid => {
648                layered::launch_ref::<R, HybridDoubleBufferingAlgorithm<Accelerated>>(
649                    client, lhs, rhs, out, selection, dtypes,
650                )
651            }
652            PartialReadingStrategy::Tma => {
653                layered::launch_ref_tma::<R, TmaDoubleBufferingAlgorithm<Accelerated>>(
654                    client, lhs, rhs, out, selection, dtypes,
655                )
656            }
657            PartialReadingStrategy::AsyncCyclic => {
658                layered::launch_ref::<R, AsyncCyclicDoubleBufferingAlgorithm<Accelerated>>(
659                    client, lhs, rhs, out, selection, dtypes,
660                )
661            }
662            PartialReadingStrategy::AsyncStrided => {
663                layered::launch_ref::<R, AsyncStridedDoubleBufferingAlgorithm<Accelerated>>(
664                    client, lhs, rhs, out, selection, dtypes,
665                )
666            }
667        }),
668        Strategy::Specialized {
669            read_strategy,
670            selection,
671            tile_kind,
672        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
673            AsyncPartialReadingStrategy::Cyclic => layered::launch_ref::<
674                R,
675                SpecializedAlgorithm<Accelerated, AsyncPartialCyclicLoading<ColMajorTilingOrder>>,
676            >(
677                client, lhs, rhs, out, selection, dtypes
678            ),
679            AsyncPartialReadingStrategy::Strided =>
680                layered::launch_ref::<
681                    R,
682                    SpecializedAlgorithm<Accelerated, AsyncPartialStridedLoading>,
683                >(client, lhs, rhs, out, selection, dtypes),
684            AsyncPartialReadingStrategy::Tma =>
685                layered::launch_ref_tma::<R, SpecializedAlgorithm<Accelerated>>(
686                    client, lhs, rhs, out, selection, dtypes
687                ),
688        }),
689        Strategy::OrderedDoubleBuffering {
690            selection,
691            tile_kind,
692        } => with_tile_kind!(tile_kind, Accelerated, || layered::launch_ref::<
693            R,
694            OrderedDoubleBufferingAlgorithm<Accelerated>,
695        >(
696            client, lhs, rhs, out, selection, dtypes
697        )),
698        Strategy::SimpleUnit(selection) => {
699            layered::launch_ref::<R, SimpleUnitAlgorithm>(client, lhs, rhs, out, selection, dtypes)
700        }
701        Strategy::DoubleUnit(selection) => {
702            layered::launch_ref::<R, DoubleUnitAlgorithm>(client, lhs, rhs, out, selection, dtypes)
703        }
704        Strategy::Naive => {
705            naive::launch_ref(client, lhs, rhs, out, dtypes)?;
706            Ok(())
707        }
708        Strategy::Auto => {
709            if let Err(err) = layered::launch_ref::<R, SimpleAlgorithm<CmmaMatmul<Filled>>>(
710                client,
711                lhs,
712                rhs,
713                out,
714                &Default::default(),
715                dtypes,
716            ) {
717                match err {
718                    MatmulSetupError::Unavailable(_) => {
719                        layered::launch_ref::<R, SimpleUnitAlgorithm>(
720                            client,
721                            lhs,
722                            rhs,
723                            out,
724                            &Default::default(),
725                            dtypes,
726                        )
727                        .unwrap();
728                    }
729                    _ => panic!("{err:?}"),
730                }
731            }
732
733            Ok(())
734        }
735        Strategy::SimpleVecMat(selection) => layered::launch_ref::<R, SimpleVecMatAlgorithm>(
736            client, lhs, rhs, out, selection, dtypes,
737        ),
738        Strategy::DoubleVecMat(selection) => layered::launch_ref::<R, DoubleVecMatAlgorithm>(
739            client, lhs, rhs, out, selection, dtypes,
740        ),
741    }
742}