cubecl_linalg/matmul/
base.rs

1use cubecl_core::{Runtime, client::ComputeClient, prelude::TensorHandleRef};
2
3use crate::tensor::TensorHandle;
4
5use super::{
6    components::{
7        MatmulPrecision,
8        global::load::{
9            async_full_cooperative, async_full_cyclic, async_full_maximize_slice_length,
10            async_full_maximize_unit_count, sync_full_strided, sync_full_tilewise,
11        },
12        stage::{ColMajorTilingOrder, RowMajorTilingOrder},
13        tile::accelerated::Accelerated,
14    },
15    kernels::{
16        MatmulLaunchError,
17        matmul::{
18            self, double_buffering::DoubleBufferingAlgorithm,
19            double_buffering_barrier::DoubleBufferingBarrierAlgorithm, simple::SimpleAlgorithm,
20            simple_barrier::SimpleBarrierAlgorithm, simple_pipelined::SimplePipelinedAlgorithm,
21            simple_tma::SimpleTmaAlgorithm, specialized::SpecializedAlgorithm,
22        },
23        naive,
24        tiling2d::{self, Tiling2dConfig},
25    },
26};
27
28#[derive(Debug, Clone, Default)]
29pub enum Strategy {
30    Simple(SyncLoadingStrategy),
31    SimpleBarrier(AsyncLoadingStrategy),
32    SimplePipelined,
33    DoubleBuffering,
34    DoubleBufferingBarrier,
35    Specialized,
36    Naive,
37    Tiling2D(Tiling2dConfig),
38    #[default]
39    Auto,
40}
41
/// Synchronous loading strategies accepted by `Strategy::Simple`.
///
/// The enum only has unit variants, so it is `Copy` and can be compared,
/// hashed, and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SyncLoadingStrategy {
    /// Uses `SimpleAlgorithm`'s default loading strategies.
    Cyclic,
    /// Uses `sync_full_strided::LoadingStrategy` for both loaders.
    Strided,
    /// Uses `sync_full_tilewise::LoadingStrategy`, with `ColMajorTilingOrder`
    /// for the first loader and `RowMajorTilingOrder` for the second
    /// (presumably lhs and rhs respectively — confirm against `SimpleAlgorithm`).
    Tilewise,
}
48
/// Asynchronous loading strategies accepted by `Strategy::SimpleBarrier`.
///
/// The enum only has unit variants, so it is `Copy` and can be compared,
/// hashed, and used as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AsyncLoadingStrategy {
    /// Uses `async_full_cooperative::LoadingStrategy`.
    Cooperative,
    /// Uses `async_full_cyclic::LoadingStrategy` with `ColMajorTilingOrder`.
    Cyclic,
    /// Uses `async_full_maximize_slice_length::LoadingStrategy`.
    MaximizeSliceLength,
    /// Uses `async_full_maximize_unit_count::LoadingStrategy`.
    MaximizeUnitCount,
    /// Launched through `SimpleTmaAlgorithm` rather than `SimpleBarrierAlgorithm`.
    Tma,
}
57
58#[allow(clippy::result_large_err)]
59pub fn launch<R: Runtime, MP: MatmulPrecision>(
60    strategy: &Strategy,
61    client: &ComputeClient<R::Server, R::Channel>,
62    lhs: TensorHandle<R, MP::EI>,
63    rhs: TensorHandle<R, MP::EI>,
64    out: TensorHandle<R, MP::EO>,
65) -> Result<(), MatmulLaunchError> {
66    launch_ref::<R, MP>(
67        strategy,
68        client,
69        &lhs.as_ref(),
70        &rhs.as_ref(),
71        &out.as_ref(),
72    )
73}
74
75#[allow(clippy::result_large_err)]
76pub fn launch_ref<R: Runtime, MP: MatmulPrecision>(
77    strategy: &Strategy,
78    client: &ComputeClient<R::Server, R::Channel>,
79    lhs: &TensorHandleRef<R>,
80    rhs: &TensorHandleRef<R>,
81    out: &TensorHandleRef<R>,
82) -> Result<(), MatmulLaunchError> {
83    match strategy {
84        Strategy::Simple(loading_strategy) => match loading_strategy {
85            SyncLoadingStrategy::Cyclic => {
86                matmul::launch_ref::<R, MP, SimpleAlgorithm<Accelerated>>(client, lhs, rhs, out)
87            }
88            SyncLoadingStrategy::Strided => matmul::launch_ref::<
89                R,
90                MP,
91                SimpleAlgorithm<
92                    Accelerated,
93                    sync_full_strided::LoadingStrategy,
94                    sync_full_strided::LoadingStrategy,
95                >,
96            >(client, lhs, rhs, out),
97            SyncLoadingStrategy::Tilewise => matmul::launch_ref::<
98                R,
99                MP,
100                SimpleAlgorithm<
101                    Accelerated,
102                    sync_full_tilewise::LoadingStrategy<ColMajorTilingOrder>,
103                    sync_full_tilewise::LoadingStrategy<RowMajorTilingOrder>,
104                >,
105            >(client, lhs, rhs, out),
106        },
107        Strategy::SimpleBarrier(loading_strategy) => match loading_strategy {
108            AsyncLoadingStrategy::Cooperative => matmul::launch_ref::<
109                R,
110                MP,
111                SimpleBarrierAlgorithm<Accelerated, async_full_cooperative::LoadingStrategy>,
112            >(client, lhs, rhs, out),
113            AsyncLoadingStrategy::Cyclic => matmul::launch_ref::<
114                R,
115                MP,
116                SimpleBarrierAlgorithm<
117                    Accelerated,
118                    async_full_cyclic::LoadingStrategy<ColMajorTilingOrder>,
119                >,
120            >(client, lhs, rhs, out),
121            AsyncLoadingStrategy::MaximizeSliceLength => matmul::launch_ref::<
122                R,
123                MP,
124                SimpleBarrierAlgorithm<
125                    Accelerated,
126                    async_full_maximize_slice_length::LoadingStrategy,
127                >,
128            >(client, lhs, rhs, out),
129            AsyncLoadingStrategy::MaximizeUnitCount => matmul::launch_ref::<
130                R,
131                MP,
132                SimpleBarrierAlgorithm<
133                    Accelerated,
134                    async_full_maximize_unit_count::LoadingStrategy,
135                >,
136            >(client, lhs, rhs, out),
137            AsyncLoadingStrategy::Tma => matmul::matmul_cmma_tma_ref_no_check::<
138                R,
139                MP,
140                SimpleTmaAlgorithm<Accelerated>,
141            >(client, lhs, rhs, out, (false, false)),
142        },
143        Strategy::SimplePipelined => {
144            matmul::launch_ref::<R, MP, SimplePipelinedAlgorithm<Accelerated>>(
145                client, lhs, rhs, out,
146            )
147        }
148        Strategy::DoubleBuffering => {
149            matmul::launch_ref::<R, MP, DoubleBufferingAlgorithm<Accelerated>>(
150                client, lhs, rhs, out,
151            )
152        }
153        Strategy::DoubleBufferingBarrier => {
154            matmul::launch_ref::<R, MP, DoubleBufferingBarrierAlgorithm<Accelerated>>(
155                client, lhs, rhs, out,
156            )
157        }
158        Strategy::Specialized => {
159            matmul::launch_ref::<R, MP, SpecializedAlgorithm<Accelerated>>(client, lhs, rhs, out)
160        }
161        Strategy::Tiling2D(config) => {
162            // TODO Implement tiling2d with EI and EO
163            tiling2d::launch_ref::<R, MP::EI>(client, lhs, rhs, out, config.clone());
164            Ok(())
165        }
166        Strategy::Naive => {
167            // TODO Implement naive with EI and EO
168            naive::launch_ref::<R, MP::EI>(client, lhs, rhs, out)?;
169            Ok(())
170        }
171        Strategy::Auto => {
172            if let Err(err) =
173                matmul::launch_ref::<R, MP, SimpleAlgorithm<Accelerated>>(client, lhs, rhs, out)
174            {
175                match err {
176                    super::kernels::MatmulLaunchError::Unavailable(_) => {
177                        // TODO Implement naive with EI and EO
178                        tiling2d::launch_ref::<R, MP::EI>(
179                            client,
180                            lhs,
181                            rhs,
182                            out,
183                            Tiling2dConfig::default(),
184                        )
185                    }
186                    _ => panic!("{err:?}"),
187                }
188            }
189
190            Ok(())
191        }
192    }
193}