cubecl_matmul/
base.rs

1use cubecl_core::{Runtime, client::ComputeClient, prelude::TensorHandleRef};
2
3use cubecl_std::tensor::TensorHandle;
4
5use crate::{
6    components::{MatmulSetupError, tile::accelerated::AcceleratedMatmul},
7    kernels::layered::{
8        Selection,
9        double_buffering::DoubleBufferingArgs,
10        double_unit::{DoubleUnitAlgorithm, DoubleUnitSelectionArgs},
11        ordered_double_buffering::OrderedSelectionArgs,
12        simple::SimpleArgs,
13        simple_unit::SimpleUnitSelectionArgs,
14    },
15};
16
17use super::{
18    components::{
19        MatmulPrecision,
20        global::load::{
21            async_full_cooperative, async_full_cyclic, async_full_maximize_slice_length,
22            async_full_maximize_unit_count, sync_full_strided, sync_full_tilewise,
23        },
24        stage::{ColMajorTilingOrder, RowMajorTilingOrder},
25    },
26    kernels::{
27        layered::{
28            self,
29            double_buffering::{
30                CyclicDoubleBufferingAlgorithm, HybridDoubleBufferingAlgorithm,
31                TilewiseDoubleBufferingAlgorithm,
32            },
33            ordered_double_buffering::OrderedDoubleBufferingAlgorithm,
34            simple::SimpleAlgorithm,
35            simple_barrier::SimpleBarrierAlgorithm,
36            simple_tma::SimpleTmaAlgorithm,
37            simple_unit::SimpleUnitAlgorithm,
38        },
39        naive,
40    },
41};
42
43#[derive(Debug, Clone, Default)]
44/// The matmul algorithm to launch
45///
46/// Most strategies have a selection input that can be overwritten or inferred from minimal information
47/// Some strategies must have a specified loading strategy
48pub enum Strategy {
49    Simple(SyncLoadingStrategy, Selection<SimpleArgs>),
50    SimpleBarrier(AsyncLoadingStrategy),
51    DoubleBuffering(SyncPartialLoadingStrategy, Selection<DoubleBufferingArgs>),
52    SimpleUnit(Selection<SimpleUnitSelectionArgs>),
53    DoubleUnit(Selection<DoubleUnitSelectionArgs>),
54    OrderedDoubleBuffering(Selection<OrderedSelectionArgs>),
55    Naive,
56    #[default]
57    /// Tries using a Simple matmul, then a SimpleUnit if the former failed
58    Auto,
59}
60
/// Which synchronous full loader the `Simple` strategy uses.
///
/// Fieldless, so it also derives `PartialEq`, `Eq` and `Hash`. This lets callers compare
/// strategies or use them as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SyncLoadingStrategy {
    /// `SimpleAlgorithm` with its default loader type parameters.
    /// NOTE(review): the default loaders are declared elsewhere; presumably cyclic — confirm.
    Cyclic,
    /// `sync_full_strided::SyncFullStridedLoading` for both loader type parameters.
    Strided,
    /// `sync_full_tilewise::SyncFullTilewiseLoading`, column-major tiling order for the first
    /// loader type parameter and row-major for the second.
    Tilewise,
}
68
/// Which double-buffering algorithm the `DoubleBuffering` strategy uses.
///
/// Fieldless, so it also derives `PartialEq`, `Eq` and `Hash`. This lets callers compare
/// strategies or use them as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SyncPartialLoadingStrategy {
    /// `CyclicDoubleBufferingAlgorithm`.
    Cyclic,
    /// `TilewiseDoubleBufferingAlgorithm`.
    Tilewise,
    /// `HybridDoubleBufferingAlgorithm`.
    Hybrid,
}
76
/// Which asynchronous loader the `SimpleBarrier` strategy uses.
///
/// Fieldless, so it also derives `PartialEq`, `Eq` and `Hash`. This lets callers compare
/// strategies or use them as map/set keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AsyncLoadingStrategy {
    /// `async_full_cooperative::AsyncFullCooperativeLoading`.
    Cooperative,
    /// `async_full_cyclic::AsyncFullCyclicLoading` with a column-major tiling order.
    Cyclic,
    /// `async_full_maximize_slice_length::AsyncFullMaximizeSliceLengthLoading`.
    MaximizeSliceLength,
    /// `async_full_maximize_unit_count::AsyncFullMaximizeUnitCountLoading`.
    MaximizeUnitCount,
    /// Does not use a barrier loader. Launches `SimpleTmaAlgorithm` through
    /// `layered::matmul_cmma_tma_ref_no_check`.
    Tma,
}
86
87#[allow(clippy::result_large_err)]
88pub fn launch<R: Runtime, MP: MatmulPrecision>(
89    strategy: &Strategy,
90    client: &ComputeClient<R::Server, R::Channel>,
91    lhs: TensorHandle<R, MP::EI>,
92    lhs_scale: Option<TensorHandle<R, f32>>,
93    rhs: TensorHandle<R, MP::EI>,
94    rhs_scale: Option<TensorHandle<R, f32>>,
95    out: TensorHandle<R, MP::EO>,
96) -> Result<(), MatmulSetupError> {
97    launch_ref::<R, MP>(
98        strategy,
99        client,
100        &lhs.as_ref(),
101        &lhs_scale.as_ref().map(|it| it.as_ref()),
102        &rhs.as_ref(),
103        &rhs_scale.as_ref().map(|it| it.as_ref()),
104        &out.as_ref(),
105    )
106}
107
108#[allow(clippy::result_large_err)]
109pub fn launch_ref<R: Runtime, MP: MatmulPrecision>(
110    strategy: &Strategy,
111    client: &ComputeClient<R::Server, R::Channel>,
112    lhs: &TensorHandleRef<R>,
113    lhs_scale: &Option<TensorHandleRef<R>>,
114    rhs: &TensorHandleRef<R>,
115    rhs_scale: &Option<TensorHandleRef<R>>,
116    out: &TensorHandleRef<R>,
117) -> Result<(), MatmulSetupError> {
118    match strategy {
119        Strategy::Simple(loading_strategy, selection) => match loading_strategy {
120            SyncLoadingStrategy::Cyclic => {
121                layered::launch_ref::<R, MP, SimpleAlgorithm<AcceleratedMatmul>>(
122                    client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
123                )
124            }
125            SyncLoadingStrategy::Strided => {
126                layered::launch_ref::<
127                    R,
128                    MP,
129                    SimpleAlgorithm<
130                        AcceleratedMatmul,
131                        sync_full_strided::SyncFullStridedLoading,
132                        sync_full_strided::SyncFullStridedLoading,
133                    >,
134                >(client, lhs, lhs_scale, rhs, rhs_scale, out, selection)
135            }
136            SyncLoadingStrategy::Tilewise => layered::launch_ref::<
137                R,
138                MP,
139                SimpleAlgorithm<
140                    AcceleratedMatmul,
141                    sync_full_tilewise::SyncFullTilewiseLoading<ColMajorTilingOrder>,
142                    sync_full_tilewise::SyncFullTilewiseLoading<RowMajorTilingOrder>,
143                >,
144            >(
145                client,
146                lhs,
147                lhs_scale,
148                rhs,
149                rhs_scale,
150                out,
151                &Default::default(),
152            ),
153        },
154        Strategy::SimpleBarrier(loading_strategy) => match loading_strategy {
155            AsyncLoadingStrategy::Cooperative => layered::launch_ref::<
156                R,
157                MP,
158                SimpleBarrierAlgorithm<
159                    AcceleratedMatmul,
160                    async_full_cooperative::AsyncFullCooperativeLoading,
161                >,
162            >(
163                client,
164                lhs,
165                lhs_scale,
166                rhs,
167                rhs_scale,
168                out,
169                &Default::default(),
170            ),
171            AsyncLoadingStrategy::Cyclic => layered::launch_ref::<
172                R,
173                MP,
174                SimpleBarrierAlgorithm<
175                    AcceleratedMatmul,
176                    async_full_cyclic::AsyncFullCyclicLoading<ColMajorTilingOrder>,
177                >,
178            >(
179                client,
180                lhs,
181                lhs_scale,
182                rhs,
183                rhs_scale,
184                out,
185                &Default::default(),
186            ),
187            AsyncLoadingStrategy::MaximizeSliceLength => layered::launch_ref::<
188                R,
189                MP,
190                SimpleBarrierAlgorithm<
191                    AcceleratedMatmul,
192                    async_full_maximize_slice_length::AsyncFullMaximizeSliceLengthLoading,
193                >,
194            >(
195                client,
196                lhs,
197                lhs_scale,
198                rhs,
199                rhs_scale,
200                out,
201                &Default::default(),
202            ),
203            AsyncLoadingStrategy::MaximizeUnitCount => layered::launch_ref::<
204                R,
205                MP,
206                SimpleBarrierAlgorithm<
207                    AcceleratedMatmul,
208                    async_full_maximize_unit_count::AsyncFullMaximizeUnitCountLoading,
209                >,
210            >(
211                client,
212                lhs,
213                lhs_scale,
214                rhs,
215                rhs_scale,
216                out,
217                &Default::default(),
218            ),
219            AsyncLoadingStrategy::Tma => {
220                layered::matmul_cmma_tma_ref_no_check::<R, MP, SimpleTmaAlgorithm<AcceleratedMatmul>>(
221                    client,
222                    lhs,
223                    lhs_scale,
224                    rhs,
225                    rhs_scale,
226                    out,
227                    (false, false),
228                    &Default::default(),
229                )
230            }
231        },
232        Strategy::DoubleBuffering(loading_strategy, selection) => match loading_strategy {
233            SyncPartialLoadingStrategy::Cyclic => {
234                layered::launch_ref::<R, MP, CyclicDoubleBufferingAlgorithm<AcceleratedMatmul>>(
235                    client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
236                )
237            }
238            SyncPartialLoadingStrategy::Tilewise => {
239                layered::launch_ref::<R, MP, TilewiseDoubleBufferingAlgorithm<AcceleratedMatmul>>(
240                    client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
241                )
242            }
243            SyncPartialLoadingStrategy::Hybrid => {
244                layered::launch_ref::<R, MP, HybridDoubleBufferingAlgorithm<AcceleratedMatmul>>(
245                    client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
246                )
247            }
248        },
249        Strategy::OrderedDoubleBuffering(selection) => {
250            layered::launch_ref::<R, MP, OrderedDoubleBufferingAlgorithm<AcceleratedMatmul>>(
251                client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
252            )
253        }
254        Strategy::SimpleUnit(selection) => layered::launch_ref::<R, MP, SimpleUnitAlgorithm>(
255            client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
256        ),
257        Strategy::DoubleUnit(selection) => layered::launch_ref::<R, MP, DoubleUnitAlgorithm>(
258            client, lhs, lhs_scale, rhs, rhs_scale, out, selection,
259        ),
260        Strategy::Naive => {
261            // TODO Implement naive with EI and EO
262            naive::launch_ref::<R, MP::EI>(client, lhs, rhs, out)?;
263            Ok(())
264        }
265        Strategy::Auto => {
266            if let Err(err) = layered::launch_ref::<R, MP, SimpleAlgorithm<AcceleratedMatmul>>(
267                client,
268                lhs,
269                lhs_scale,
270                rhs,
271                rhs_scale,
272                out,
273                &Default::default(),
274            ) {
275                match err {
276                    MatmulSetupError::Unavailable(_) => {
277                        layered::launch_ref::<R, MP, SimpleUnitAlgorithm>(
278                            client,
279                            lhs,
280                            lhs_scale,
281                            rhs,
282                            rhs_scale,
283                            out,
284                            &Default::default(),
285                        )
286                        .unwrap();
287                    }
288                    _ => panic!("{err:?}"),
289                }
290            }
291
292            Ok(())
293        }
294    }
295}