1use cubecl_core::{Runtime, client::ComputeClient, prelude::TensorHandleRef};
2
3use crate::tensor::TensorHandle;
4
5use super::{
6 components::{
7 MatmulPrecision,
8 global::load::{
9 async_full_cooperative, async_full_cyclic, async_full_maximize_slice_length,
10 async_full_maximize_unit_count, sync_full_strided, sync_full_tilewise,
11 },
12 stage::{ColMajorTilingOrder, RowMajorTilingOrder},
13 tile::accelerated::Accelerated,
14 },
15 kernels::{
16 MatmulLaunchError,
17 matmul::{
18 self, double_buffering::DoubleBufferingAlgorithm,
19 double_buffering_barrier::DoubleBufferingBarrierAlgorithm, simple::SimpleAlgorithm,
20 simple_barrier::SimpleBarrierAlgorithm, simple_pipelined::SimplePipelinedAlgorithm,
21 simple_tma::SimpleTmaAlgorithm, specialized::SpecializedAlgorithm,
22 },
23 naive,
24 tiling2d::{self, Tiling2dConfig},
25 },
26};
27
28#[derive(Debug, Clone, Default)]
29pub enum Strategy {
30 Simple(SyncLoadingStrategy),
31 SimpleBarrier(AsyncLoadingStrategy),
32 SimplePipelined,
33 DoubleBuffering,
34 DoubleBufferingBarrier,
35 Specialized,
36 Naive,
37 Tiling2D(Tiling2dConfig),
38 #[default]
39 Auto,
40}
41
/// Loader choice carried by `Strategy::Simple`.
#[derive(Clone, Debug)]
pub enum SyncLoadingStrategy {
    /// Uses `SimpleAlgorithm` without naming loaders explicitly, i.e. with its default
    /// loader type parameters.
    Cyclic,
    /// Uses `sync_full_strided::LoadingStrategy` for both loader type parameters.
    Strided,
    /// Uses `sync_full_tilewise::LoadingStrategy`, with `ColMajorTilingOrder` for the first
    /// loader type parameter and `RowMajorTilingOrder` for the second.
    Tilewise,
}
48
/// Loader choice carried by `Strategy::SimpleBarrier`.
#[derive(Clone, Debug)]
pub enum AsyncLoadingStrategy {
    /// Uses `async_full_cooperative::LoadingStrategy`.
    Cooperative,
    /// Uses `async_full_cyclic::LoadingStrategy` with `ColMajorTilingOrder`.
    Cyclic,
    /// Uses `async_full_maximize_slice_length::LoadingStrategy`.
    MaximizeSliceLength,
    /// Uses `async_full_maximize_unit_count::LoadingStrategy`.
    MaximizeUnitCount,
    /// Uses `SimpleTmaAlgorithm`, launched through `matmul_cmma_tma_ref_no_check`.
    Tma,
}
57
58#[allow(clippy::result_large_err)]
59pub fn launch<R: Runtime, MP: MatmulPrecision>(
60 strategy: &Strategy,
61 client: &ComputeClient<R::Server, R::Channel>,
62 lhs: TensorHandle<R, MP::EI>,
63 rhs: TensorHandle<R, MP::EI>,
64 out: TensorHandle<R, MP::EO>,
65) -> Result<(), MatmulLaunchError> {
66 launch_ref::<R, MP>(
67 strategy,
68 client,
69 &lhs.as_ref(),
70 &rhs.as_ref(),
71 &out.as_ref(),
72 )
73}
74
75#[allow(clippy::result_large_err)]
76pub fn launch_ref<R: Runtime, MP: MatmulPrecision>(
77 strategy: &Strategy,
78 client: &ComputeClient<R::Server, R::Channel>,
79 lhs: &TensorHandleRef<R>,
80 rhs: &TensorHandleRef<R>,
81 out: &TensorHandleRef<R>,
82) -> Result<(), MatmulLaunchError> {
83 match strategy {
84 Strategy::Simple(loading_strategy) => match loading_strategy {
85 SyncLoadingStrategy::Cyclic => {
86 matmul::launch_ref::<R, MP, SimpleAlgorithm<Accelerated>>(client, lhs, rhs, out)
87 }
88 SyncLoadingStrategy::Strided => matmul::launch_ref::<
89 R,
90 MP,
91 SimpleAlgorithm<
92 Accelerated,
93 sync_full_strided::LoadingStrategy,
94 sync_full_strided::LoadingStrategy,
95 >,
96 >(client, lhs, rhs, out),
97 SyncLoadingStrategy::Tilewise => matmul::launch_ref::<
98 R,
99 MP,
100 SimpleAlgorithm<
101 Accelerated,
102 sync_full_tilewise::LoadingStrategy<ColMajorTilingOrder>,
103 sync_full_tilewise::LoadingStrategy<RowMajorTilingOrder>,
104 >,
105 >(client, lhs, rhs, out),
106 },
107 Strategy::SimpleBarrier(loading_strategy) => match loading_strategy {
108 AsyncLoadingStrategy::Cooperative => matmul::launch_ref::<
109 R,
110 MP,
111 SimpleBarrierAlgorithm<Accelerated, async_full_cooperative::LoadingStrategy>,
112 >(client, lhs, rhs, out),
113 AsyncLoadingStrategy::Cyclic => matmul::launch_ref::<
114 R,
115 MP,
116 SimpleBarrierAlgorithm<
117 Accelerated,
118 async_full_cyclic::LoadingStrategy<ColMajorTilingOrder>,
119 >,
120 >(client, lhs, rhs, out),
121 AsyncLoadingStrategy::MaximizeSliceLength => matmul::launch_ref::<
122 R,
123 MP,
124 SimpleBarrierAlgorithm<
125 Accelerated,
126 async_full_maximize_slice_length::LoadingStrategy,
127 >,
128 >(client, lhs, rhs, out),
129 AsyncLoadingStrategy::MaximizeUnitCount => matmul::launch_ref::<
130 R,
131 MP,
132 SimpleBarrierAlgorithm<
133 Accelerated,
134 async_full_maximize_unit_count::LoadingStrategy,
135 >,
136 >(client, lhs, rhs, out),
137 AsyncLoadingStrategy::Tma => matmul::matmul_cmma_tma_ref_no_check::<
138 R,
139 MP,
140 SimpleTmaAlgorithm<Accelerated>,
141 >(client, lhs, rhs, out, (false, false)),
142 },
143 Strategy::SimplePipelined => {
144 matmul::launch_ref::<R, MP, SimplePipelinedAlgorithm<Accelerated>>(
145 client, lhs, rhs, out,
146 )
147 }
148 Strategy::DoubleBuffering => {
149 matmul::launch_ref::<R, MP, DoubleBufferingAlgorithm<Accelerated>>(
150 client, lhs, rhs, out,
151 )
152 }
153 Strategy::DoubleBufferingBarrier => {
154 matmul::launch_ref::<R, MP, DoubleBufferingBarrierAlgorithm<Accelerated>>(
155 client, lhs, rhs, out,
156 )
157 }
158 Strategy::Specialized => {
159 matmul::launch_ref::<R, MP, SpecializedAlgorithm<Accelerated>>(client, lhs, rhs, out)
160 }
161 Strategy::Tiling2D(config) => {
162 tiling2d::launch_ref::<R, MP::EI>(client, lhs, rhs, out, config.clone());
164 Ok(())
165 }
166 Strategy::Naive => {
167 naive::launch_ref::<R, MP::EI>(client, lhs, rhs, out)?;
169 Ok(())
170 }
171 Strategy::Auto => {
172 if let Err(err) =
173 matmul::launch_ref::<R, MP, SimpleAlgorithm<Accelerated>>(client, lhs, rhs, out)
174 {
175 match err {
176 super::kernels::MatmulLaunchError::Unavailable(_) => {
177 tiling2d::launch_ref::<R, MP::EI>(
179 client,
180 lhs,
181 rhs,
182 out,
183 Tiling2dConfig::default(),
184 )
185 }
186 _ => panic!("{err:?}"),
187 }
188 }
189
190 Ok(())
191 }
192 }
193}