cubecl_matmul/
base.rs

use cubecl_common::quant::scheme::{QuantScheme, QuantStore, QuantValue};
use cubecl_core::{
    Runtime,
    client::ComputeClient,
    prelude::{CubePrimitive, Numeric, TensorHandleRef},
};

use cubecl_std::tensor::{TensorHandle, into_contiguous_packed, into_contiguous_pitched};

use crate::{
    components::{
        AccG, LhsG, MatmulSetupError, RhsG,
        tile::{accelerated::AcceleratedMatmul, io::Filled},
    },
    kernels::layered::{
        Selection,
        double_buffering::DoubleBufferingArgs,
        double_unit::{DoubleUnitAlgorithm, DoubleUnitSelectionArgs},
        ordered_double_buffering::OrderedSelectionArgs,
        simple::SimpleArgs,
        simple_unit::SimpleUnitSelectionArgs,
        vecmat::{DoubleVecMatAlgorithm, SimpleVecMatAlgorithm},
    },
};

use super::{
    components::{
        MatmulPrecision,
        global::read::{
            async_full_cooperative, async_full_cyclic, async_full_maximize_slice_length,
            async_full_maximize_unit_count, sync_full_strided, sync_full_tilewise,
        },
        stage::{ColMajorTilingOrder, RowMajorTilingOrder},
    },
    kernels::{
        layered::{
            self,
            double_buffering::{
                CyclicDoubleBufferingAlgorithm, HybridDoubleBufferingAlgorithm,
                TilewiseDoubleBufferingAlgorithm,
            },
            ordered_double_buffering::OrderedDoubleBufferingAlgorithm,
            simple::SimpleAlgorithm,
            simple_barrier::SimpleBarrierAlgorithm,
            simple_tma::SimpleTmaAlgorithm,
            simple_unit::SimpleUnitAlgorithm,
        },
        naive,
    },
};

/// The matmul algorithm to launch.
///
/// Most strategies take a selection input that can either be overridden explicitly or
/// inferred from minimal information. Some strategies also require an explicit loading
/// (reading) strategy.
#[derive(Debug, Clone, Default)]
pub enum Strategy {
    Simple(SyncReadingStrategy, Selection<SimpleArgs>),
    SimpleBarrier(AsyncReadingStrategy),
    DoubleBuffering(SyncPartialReadingStrategy, Selection<DoubleBufferingArgs>),
    SimpleUnit(Selection<SimpleUnitSelectionArgs>),
    DoubleUnit(Selection<DoubleUnitSelectionArgs>),
    SimpleVecMat(Selection<()>),
    DoubleVecMat(Selection<()>),
    OrderedDoubleBuffering(Selection<OrderedSelectionArgs>),
    Naive,
    /// Tries the Simple matmul first, falling back to SimpleUnit if it is unavailable.
    #[default]
    Auto,
}
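
// Illustrative sketch (not part of the original source): a few ways a caller might pick a
// strategy. `Selection::default()` is assumed to be available for the selection-based
// variants, mirroring the `&Default::default()` selections passed in `launch_ref` below.
#[allow(dead_code)]
fn example_strategies() -> Vec<Strategy> {
    vec![
        // The default strategy is `Auto`: try a Simple matmul, fall back to SimpleUnit.
        Strategy::default(),
        // A variant that takes no selection input: barrier-based loading with cyclic readers.
        Strategy::SimpleBarrier(AsyncReadingStrategy::Cyclic),
        // A selection-based variant with an inferred (default) selection.
        Strategy::Simple(SyncReadingStrategy::Cyclic, Selection::default()),
    ]
}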

/// Which reader to use in simple algorithms.
#[derive(Debug, Clone)]
pub enum SyncReadingStrategy {
    Cyclic,
    Strided,
    Tilewise,
}

/// Which reader to use in double buffering algorithms.
#[derive(Debug, Clone)]
pub enum SyncPartialReadingStrategy {
    Cyclic,
    Tilewise,
    Hybrid,
}

/// Which reader to use in barrier algorithms.
#[derive(Debug, Clone)]
pub enum AsyncReadingStrategy {
    Cooperative,
    Cyclic,
    MaximizeSliceLength,
    MaximizeUnitCount,
    Tma,
}

/// Owned input handle for a matmul: either a plain tensor, or quantized data with its scales.
pub enum MatmulInputHandle<R: Runtime, E: CubePrimitive, S: CubePrimitive = f32> {
    Normal(TensorHandle<R, E>),
    Quantized {
        data: TensorHandle<R, E>,
        scale: TensorHandle<R, S>,
        /// Unpacked shape, excluding padding
        shape: Vec<usize>,
        scheme: QuantScheme,
    },
}

impl<R: Runtime, E: Numeric> MatmulInputHandle<R, E> {
    /// Borrows this handle as a [`MatmulInputHandleRef`].
    pub fn as_ref(&self) -> MatmulInputHandleRef<'_, R> {
        match self {
            MatmulInputHandle::Normal(handle) => MatmulInputHandleRef::Normal(handle.as_ref()),
            MatmulInputHandle::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => MatmulInputHandleRef::Quantized {
                data: data.as_ref(),
                scale: scale.as_ref(),
                shape,
                scheme,
            },
        }
    }

    /// Creates an owned handle from a borrowed [`MatmulInputHandleRef`].
    pub fn from_ref(handle: &MatmulInputHandleRef<'_, R>) -> Self {
        match handle {
            MatmulInputHandleRef::Normal(handle) => {
                MatmulInputHandle::Normal(TensorHandle::from_ref(handle))
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => MatmulInputHandle::Quantized {
                data: TensorHandle::from_ref(data),
                scale: TensorHandle::from_ref(scale),
                shape: shape.to_vec(),
                scheme: **scheme,
            },
        }
    }

    /// Returns the underlying data handle, whether quantized or not.
    pub fn data(&self) -> &TensorHandle<R, E> {
        match self {
            MatmulInputHandle::Normal(handle) => handle,
            MatmulInputHandle::Quantized { data, .. } => data,
        }
    }

    /// Swaps two dimensions in place, adjusting shapes and strides (and the scale layout
    /// for quantized handles when the scales have the same rank as the data).
    pub fn swap_dims(&mut self, dim0: usize, dim1: usize) {
        match self {
            MatmulInputHandle::Normal(handle) => {
                handle.shape.swap(dim0, dim1);
                handle.strides.swap(dim0, dim1);
            }
            MatmulInputHandle::Quantized {
                data, scale, shape, ..
            } => {
                data.shape.swap(dim0, dim1);
                data.strides.swap(dim0, dim1);
                if scale.shape.len() == data.shape.len() {
                    scale.shape.swap(dim0, dim1);
                    scale.strides.swap(dim0, dim1);
                }
                shape.swap(dim0, dim1);
            }
        }
    }
}
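
// Illustrative sketch (not part of the original source): `swap_dims` only rewrites shape and
// stride metadata, so transposing the two innermost (matrix) dimensions is a metadata-only
// operation with no data movement. Assumes the input has rank >= 2.
#[allow(dead_code)]
fn example_transpose_input<R: Runtime, E: Numeric>(
    mut input: MatmulInputHandle<R, E>,
) -> MatmulInputHandle<R, E> {
    let rank = input.data().shape.len();
    input.swap_dims(rank - 2, rank - 1);
    input
}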

impl<R: Runtime, E: CubePrimitive> Clone for MatmulInputHandle<R, E> {
    fn clone(&self) -> Self {
        match self {
            Self::Normal(handle) => Self::Normal(handle.clone()),
            Self::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => Self::Quantized {
                data: data.clone(),
                scale: scale.clone(),
                shape: shape.clone(),
                scheme: *scheme,
            },
        }
    }
}

/// Borrowed counterpart of [`MatmulInputHandle`].
#[derive(Debug)]
pub enum MatmulInputHandleRef<'a, R: Runtime> {
    Normal(TensorHandleRef<'a, R>),
    Quantized {
        data: TensorHandleRef<'a, R>,
        scale: TensorHandleRef<'a, R>,
        /// Unpacked shape, excluding padding
        shape: &'a [usize],
        scheme: &'a QuantScheme,
    },
}

impl<'a, R: Runtime> Clone for MatmulInputHandleRef<'a, R> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<'a, R: Runtime> Copy for MatmulInputHandleRef<'a, R> {}

impl<'a, R: Runtime> MatmulInputHandleRef<'a, R> {
    /// Wraps a plain (non-quantized) tensor handle.
    pub fn new(data: TensorHandleRef<'a, R>) -> Self {
        Self::Normal(data)
    }

    /// Wraps a quantized tensor handle along with its scales, unpacked shape and scheme.
    pub fn quantized(
        data: TensorHandleRef<'a, R>,
        scale: TensorHandleRef<'a, R>,
        shape: &'a [usize],
        scheme: &'a QuantScheme,
    ) -> Self {
        Self::Quantized {
            data,
            scale,
            shape,
            scheme,
        }
    }

    /// Returns the underlying data handle, whether quantized or not.
    pub fn data(&self) -> &TensorHandleRef<'a, R> {
        match self {
            MatmulInputHandleRef::Normal(handle) => handle,
            MatmulInputHandleRef::Quantized { data, .. } => data,
        }
    }

    /// Returns a mutable reference to the underlying data handle.
    pub fn data_mut(&mut self) -> &mut TensorHandleRef<'a, R> {
        match self {
            MatmulInputHandleRef::Normal(handle) => handle,
            MatmulInputHandleRef::Quantized { data, .. } => data,
        }
    }

    /// Returns the scale handle, if the input is quantized.
    pub fn scale(&self) -> Option<&TensorHandleRef<'a, R>> {
        match self {
            MatmulInputHandleRef::Normal(_) => None,
            MatmulInputHandleRef::Quantized { scale, .. } => Some(scale),
        }
    }

    /// Returns the quantization scheme, if the input is quantized.
    pub fn scheme(&self) -> Option<&QuantScheme> {
        match self {
            MatmulInputHandleRef::Normal(_) => None,
            MatmulInputHandleRef::Quantized { scheme, .. } => Some(scheme),
        }
    }

    /// Returns the logical (unpacked) shape of the input.
    pub fn shape(&self) -> &[usize] {
        match self {
            MatmulInputHandleRef::Normal(handle) => handle.shape,
            MatmulInputHandleRef::Quantized { shape, .. } => shape,
        }
    }

    /// Copies the input into an owned handle with a contiguous (pitched) layout,
    /// re-packing quantized data when needed.
    pub fn into_contiguous<E: Numeric>(
        &self,
        client: &ComputeClient<R::Server>,
    ) -> MatmulInputHandle<R, E> {
        match self {
            MatmulInputHandleRef::Normal(data) => {
                MatmulInputHandle::Normal(into_contiguous_pitched::<R, E>(client, data))
            }
            MatmulInputHandleRef::Quantized {
                data,
                scale,
                shape,
                scheme,
            } => {
                let data = match scheme.store {
                    // e2m1 is natively packed (e2m1x2), so it also needs to be re-packed
                    QuantStore::Native if scheme.value == QuantValue::E2M1 => {
                        let data = into_contiguous_packed::<R, u8>(client, data, shape, 2);
                        // Unsafely cast the handle to element type `E`
                        TensorHandle::from_ref(&data.as_ref())
                    }
                    QuantStore::U32 => {
                        let data = into_contiguous_packed::<R, u32>(
                            client,
                            data,
                            shape,
                            scheme.num_quants() as u32,
                        );
                        // Unsafely cast the handle to element type `E`
                        TensorHandle::from_ref(&data.as_ref())
                    }
                    _ => into_contiguous_pitched::<R, E>(client, data),
                };
                MatmulInputHandle::Quantized {
                    data,
                    scale: TensorHandle::from_ref(scale),
                    shape: shape.to_vec(),
                    scheme: **scheme,
                }
            }
        }
    }
}
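
// Illustrative sketch (not part of the original source): wrapping a plain tensor handle and
// materializing it as an owned handle with a contiguous (pitched) layout before a launch.
#[allow(dead_code)]
fn example_prepare_input<R: Runtime, E: Numeric>(
    client: &ComputeClient<R::Server>,
    data: TensorHandleRef<'_, R>,
) -> MatmulInputHandle<R, E> {
    let input = MatmulInputHandleRef::new(data);
    input.into_contiguous::<E>(client)
}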

/// Launches a matmul with the given strategy, taking owned input and output handles.
#[allow(clippy::result_large_err)]
pub fn launch<R: Runtime, MP: MatmulPrecision>(
    strategy: &Strategy,
    client: &ComputeClient<R::Server>,
    lhs: MatmulInputHandle<R, LhsG<MP>>,
    rhs: MatmulInputHandle<R, RhsG<MP>>,
    out: TensorHandle<R, AccG<MP>>,
) -> Result<(), MatmulSetupError> {
    launch_ref::<R, MP>(
        strategy,
        client,
        &lhs.as_ref(),
        &rhs.as_ref(),
        &out.as_ref(),
    )
}
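
// Illustrative sketch (not part of the original source): driving `launch` with the default
// `Auto` strategy. Concrete `R` and `MP` types are supplied by the runtime in use.
#[allow(dead_code, clippy::result_large_err)]
fn launch_auto<R: Runtime, MP: MatmulPrecision>(
    client: &ComputeClient<R::Server>,
    lhs: MatmulInputHandle<R, LhsG<MP>>,
    rhs: MatmulInputHandle<R, RhsG<MP>>,
    out: TensorHandle<R, AccG<MP>>,
) -> Result<(), MatmulSetupError> {
    launch::<R, MP>(&Strategy::Auto, client, lhs, rhs, out)
}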

/// Launches a matmul with the given strategy, using borrowed tensor handles.
#[allow(clippy::result_large_err)]
pub fn launch_ref<R: Runtime, MP: MatmulPrecision>(
    strategy: &Strategy,
    client: &ComputeClient<R::Server>,
    lhs: &MatmulInputHandleRef<R>,
    rhs: &MatmulInputHandleRef<R>,
    out: &TensorHandleRef<R>,
) -> Result<(), MatmulSetupError> {
    type Accelerated = AcceleratedMatmul<Filled>;

    match strategy {
        Strategy::Simple(loading_strategy, selection) => match loading_strategy {
            SyncReadingStrategy::Cyclic => {
                layered::launch_ref::<R, MP, SimpleAlgorithm<Accelerated>>(
                    client, lhs, rhs, out, selection,
                )
            }
            SyncReadingStrategy::Strided => layered::launch_ref::<
                R,
                MP,
                SimpleAlgorithm<
                    Accelerated,
                    sync_full_strided::SyncFullStridedLoading,
                    sync_full_strided::SyncFullStridedLoading,
                >,
            >(client, lhs, rhs, out, selection),
            SyncReadingStrategy::Tilewise => {
                layered::launch_ref::<
                    R,
                    MP,
                    SimpleAlgorithm<
                        Accelerated,
                        sync_full_tilewise::SyncFullTilewiseLoading<ColMajorTilingOrder>,
                        sync_full_tilewise::SyncFullTilewiseLoading<RowMajorTilingOrder>,
                    >,
                >(client, lhs, rhs, out, &Default::default())
            }
        },
        Strategy::SimpleBarrier(loading_strategy) => match loading_strategy {
            AsyncReadingStrategy::Cooperative => {
                layered::launch_ref::<
                    R,
                    MP,
                    SimpleBarrierAlgorithm<
                        Accelerated,
                        async_full_cooperative::AsyncFullCooperativeLoading,
                    >,
                >(client, lhs, rhs, out, &Default::default())
            }
            AsyncReadingStrategy::Cyclic => {
                layered::launch_ref::<
                    R,
                    MP,
                    SimpleBarrierAlgorithm<
                        Accelerated,
                        async_full_cyclic::AsyncFullCyclicLoading<ColMajorTilingOrder>,
                    >,
                >(client, lhs, rhs, out, &Default::default())
            }
            AsyncReadingStrategy::MaximizeSliceLength => {
                layered::launch_ref::<
                    R,
                    MP,
                    SimpleBarrierAlgorithm<
                        Accelerated,
                        async_full_maximize_slice_length::AsyncFullMaximizeSliceLengthLoading,
                    >,
                >(client, lhs, rhs, out, &Default::default())
            }
            AsyncReadingStrategy::MaximizeUnitCount => {
                layered::launch_ref::<
                    R,
                    MP,
                    SimpleBarrierAlgorithm<
                        Accelerated,
                        async_full_maximize_unit_count::AsyncFullMaximizeUnitCountLoading,
                    >,
                >(client, lhs, rhs, out, &Default::default())
            }
            AsyncReadingStrategy::Tma => {
                layered::matmul_cmma_tma_ref_no_check::<R, MP, SimpleTmaAlgorithm<Accelerated>>(
                    client,
                    lhs,
                    rhs,
                    out,
                    (false, false),
                    &Default::default(),
                )
            }
        },
        Strategy::DoubleBuffering(loading_strategy, selection) => match loading_strategy {
            SyncPartialReadingStrategy::Cyclic => {
                layered::launch_ref::<R, MP, CyclicDoubleBufferingAlgorithm<Accelerated>>(
                    client, lhs, rhs, out, selection,
                )
            }
            SyncPartialReadingStrategy::Tilewise => {
                layered::launch_ref::<R, MP, TilewiseDoubleBufferingAlgorithm<Accelerated>>(
                    client, lhs, rhs, out, selection,
                )
            }
            SyncPartialReadingStrategy::Hybrid => {
                layered::launch_ref::<R, MP, HybridDoubleBufferingAlgorithm<Accelerated>>(
                    client, lhs, rhs, out, selection,
                )
            }
        },
        Strategy::OrderedDoubleBuffering(selection) => {
            layered::launch_ref::<R, MP, OrderedDoubleBufferingAlgorithm<Accelerated>>(
                client, lhs, rhs, out, selection,
            )
        }
        Strategy::SimpleUnit(selection) => {
            layered::launch_ref::<R, MP, SimpleUnitAlgorithm>(client, lhs, rhs, out, selection)
        }
        Strategy::DoubleUnit(selection) => {
            layered::launch_ref::<R, MP, DoubleUnitAlgorithm>(client, lhs, rhs, out, selection)
        }
        Strategy::Naive => {
            naive::launch_ref::<R, LhsG<MP>, AccG<MP>>(client, lhs, rhs, out)?;
            Ok(())
        }
        Strategy::Auto => {
            if let Err(err) = layered::launch_ref::<R, MP, SimpleAlgorithm<Accelerated>>(
                client,
                lhs,
                rhs,
                out,
                &Default::default(),
            ) {
                match err {
                    // Fall back to the unit matmul only when the accelerated one is unavailable.
                    MatmulSetupError::Unavailable(_) => {
                        layered::launch_ref::<R, MP, SimpleUnitAlgorithm>(
                            client,
                            lhs,
                            rhs,
                            out,
                            &Default::default(),
                        )
                        .unwrap();
                    }
                    _ => panic!("{err:?}"),
                }
            }

            Ok(())
        }
        Strategy::SimpleVecMat(selection) => {
            layered::launch_ref::<R, MP, SimpleVecMatAlgorithm>(client, lhs, rhs, out, selection)
        }
        Strategy::DoubleVecMat(selection) => {
            layered::launch_ref::<R, MP, DoubleVecMatAlgorithm>(client, lhs, rhs, out, selection)
        }
    }
}
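
// Minimal sketch (added for illustration, not part of the original source): the default
// strategy resolves to `Auto`.
#[cfg(test)]
mod strategy_default_test {
    use super::*;

    #[test]
    fn default_strategy_is_auto() {
        assert!(matches!(Strategy::default(), Strategy::Auto));
    }
}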