1use cubecl::{Runtime, client::ComputeClient};
8use cubek_matmul::{
9 components::tile::TileMatmulKind,
10 definition::{MatmulElems, TilingBlueprint},
11 routines::{BlueprintStrategy, Routine as MatmulRoutine, TilingArgs},
12};
13
14use crate::components::ConvolutionOperation;
15use crate::definition::ConvBlueprint;
16
17fn blueprint_operation(blueprint: &ConvBlueprint) -> ConvolutionOperation {
18 match blueprint {
19 ConvBlueprint::Forward(_) => ConvolutionOperation::Forward,
20 ConvBlueprint::BackwardData(_) => ConvolutionOperation::BackwardData,
21 ConvBlueprint::BackwardWeight(_) => ConvolutionOperation::BackwardWeight,
22 }
23}
24
25use crate::{
26 components::{ConvSetupError, global::args::RuntimeArgs},
27 kernels::{backward_data, backward_weight, forward},
28 launch::{
29 ConvAlgorithm, ConvolutionArgs, ConvolutionInputs, Strategy, strategy::AcceleratedTileKind,
30 },
31 routines::{
32 Routine,
33 simple::{
34 SimpleAsyncCyclicConv, SimpleAsyncStridedConv, SimpleAsyncTmaConv,
35 SimpleSyncCyclicConv, SimpleSyncStridedConv, SimpleSyncTilewiseConv,
36 },
37 specialized::{
38 SpecializedAsyncCyclicConv, SpecializedAsyncStridedConv, SpecializedTmaConv,
39 },
40 },
41};
42
43pub(crate) fn tile_kind_to_dispatch(kind: AcceleratedTileKind) -> TileMatmulKind {
45 match kind {
46 AcceleratedTileKind::Cmma => TileMatmulKind::Cmma,
47 AcceleratedTileKind::Mma => TileMatmulKind::Mma,
48 }
49}
50
51#[allow(clippy::result_large_err)]
57pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
58 strategy: &Strategy,
59 client: &ComputeClient<R>,
60 inputs: ConvolutionInputs<R>,
61 args: ConvolutionArgs<N_SPATIAL>,
62 dtypes: MatmulElems,
63) -> Result<(), ConvSetupError> {
64 let (algorithm, tile_kind, forced_matmul) = match strategy {
65 Strategy::Inferred {
66 algorithm,
67 tile_kind,
68 } => (*algorithm, *tile_kind, None),
69 Strategy::Forced {
70 algorithm,
71 blueprint,
72 } => {
73 debug_assert_eq!(
74 inputs.operation(),
75 blueprint_operation(blueprint),
76 "Strategy::Forced blueprint variant does not match the inputs operation",
77 );
78 let matmul = blueprint.matmul().clone();
79 (*algorithm, AcceleratedTileKind::Cmma, Some(matmul))
83 }
84 };
85
86 if inputs.operation() == ConvolutionOperation::BackwardData
88 && algorithm == ConvAlgorithm::SimpleAsyncTma
89 {
90 return Err(crate::kernels::backward_data::launch::unsupported_tma_error());
91 }
92
93 dispatch_routine::<R, N_SPATIAL>(
94 algorithm,
95 tile_kind,
96 forced_matmul,
97 client,
98 inputs,
99 args,
100 dtypes,
101 )
102}
103
104#[allow(clippy::result_large_err, clippy::too_many_arguments)]
107fn dispatch_routine<R: Runtime, const N_SPATIAL: usize>(
108 algorithm: ConvAlgorithm,
109 tile_kind: AcceleratedTileKind,
110 forced_matmul: Option<TilingBlueprint>,
111 client: &ComputeClient<R>,
112 inputs: ConvolutionInputs<R>,
113 args: ConvolutionArgs<N_SPATIAL>,
114 dtypes: MatmulElems,
115) -> Result<(), ConvSetupError> {
116 let kind = tile_kind_to_dispatch(tile_kind);
117 match algorithm {
118 ConvAlgorithm::SimpleSyncCyclic => dispatch_inputs::<R, N_SPATIAL, SimpleSyncCyclicConv>(
119 client,
120 inputs,
121 args,
122 kind,
123 forced_matmul,
124 dtypes,
125 ),
126 ConvAlgorithm::SimpleSyncStrided => dispatch_inputs::<R, N_SPATIAL, SimpleSyncStridedConv>(
127 client,
128 inputs,
129 args,
130 kind,
131 forced_matmul,
132 dtypes,
133 ),
134 ConvAlgorithm::SimpleSyncTilewise => {
135 dispatch_inputs::<R, N_SPATIAL, SimpleSyncTilewiseConv>(
136 client,
137 inputs,
138 args,
139 kind,
140 forced_matmul,
141 dtypes,
142 )
143 }
144 ConvAlgorithm::SimpleAsyncCyclic => dispatch_inputs::<R, N_SPATIAL, SimpleAsyncCyclicConv>(
145 client,
146 inputs,
147 args,
148 kind,
149 forced_matmul,
150 dtypes,
151 ),
152 ConvAlgorithm::SimpleAsyncStrided => {
153 dispatch_inputs::<R, N_SPATIAL, SimpleAsyncStridedConv>(
154 client,
155 inputs,
156 args,
157 kind,
158 forced_matmul,
159 dtypes,
160 )
161 }
162 ConvAlgorithm::SimpleAsyncTma => dispatch_inputs::<R, N_SPATIAL, SimpleAsyncTmaConv>(
163 client,
164 inputs,
165 args,
166 kind,
167 forced_matmul,
168 dtypes,
169 ),
170 ConvAlgorithm::SpecializedAsyncCyclic => {
171 dispatch_inputs::<R, N_SPATIAL, SpecializedAsyncCyclicConv>(
172 client,
173 inputs,
174 args,
175 kind,
176 forced_matmul,
177 dtypes,
178 )
179 }
180 ConvAlgorithm::SpecializedAsyncStrided => {
181 dispatch_inputs::<R, N_SPATIAL, SpecializedAsyncStridedConv>(
182 client,
183 inputs,
184 args,
185 kind,
186 forced_matmul,
187 dtypes,
188 )
189 }
190 ConvAlgorithm::SpecializedTma => dispatch_inputs::<R, N_SPATIAL, SpecializedTmaConv>(
191 client,
192 inputs,
193 args,
194 kind,
195 forced_matmul,
196 dtypes,
197 ),
198 }
199}
200
201#[allow(clippy::result_large_err, clippy::too_many_arguments)]
207fn dispatch_inputs<R: Runtime, const N_SPATIAL: usize, Rt: Routine<Blueprint = TilingBlueprint>>(
208 client: &ComputeClient<R>,
209 inputs: ConvolutionInputs<R>,
210 args: ConvolutionArgs<N_SPATIAL>,
211 tile_matmul: TileMatmulKind,
212 forced_matmul: Option<TilingBlueprint>,
213 dtypes: MatmulElems,
214) -> Result<(), ConvSetupError>
215where
216 Rt::Args: forward::args::ConcreteArgs<Rt::MatmulRoutine>
217 + backward_data::args::ConcreteArgs<Rt::MatmulRoutine>
218 + backward_weight::args::ConcreteArgs<Rt::MatmulRoutine>,
219 Rt::Strategy: TilingArgs,
220{
221 let blueprint_strategy = build_blueprint_strategy::<Rt>(tile_matmul, forced_matmul);
222
223 match inputs {
224 ConvolutionInputs::Forward {
225 input,
226 weight,
227 bias,
228 out,
229 } => forward::launch::launch_internal::<R, N_SPATIAL, Rt>(
230 client,
231 input,
232 weight,
233 bias,
234 out,
235 args,
236 &blueprint_strategy,
237 dtypes,
238 ),
239 ConvolutionInputs::BackwardData {
240 out_grad,
241 weights,
242 in_grad,
243 } => backward_data::launch::launch_internal::<R, N_SPATIAL, Rt>(
244 client,
245 out_grad,
246 weights,
247 in_grad,
248 args,
249 &blueprint_strategy,
250 dtypes,
251 ),
252 ConvolutionInputs::BackwardWeight {
253 input,
254 out_grad,
255 weight_grad,
256 } => backward_weight::launch::launch_internal::<R, N_SPATIAL, Rt>(
257 client,
258 input,
259 out_grad,
260 weight_grad,
261 args,
262 &blueprint_strategy,
263 dtypes,
264 ),
265 }
266}
267
268fn build_blueprint_strategy<Rt: Routine<Blueprint = TilingBlueprint>>(
272 tile_matmul: TileMatmulKind,
273 forced_matmul: Option<TilingBlueprint>,
274) -> BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>
275where
276 Rt::Strategy: TilingArgs,
277{
278 match forced_matmul {
279 Some(matmul) => BlueprintStrategy::Forced(matmul),
280 None => {
281 let mut s = <Rt::MatmulRoutine as MatmulRoutine<RuntimeArgs>>::Strategy::default();
282 s.set_tile_matmul(tile_matmul);
283 BlueprintStrategy::Inferred(s)
284 }
285 }
286}