// cubek_convolution/launch/base.rs
//! Unified `launch_ref` entry point for the convolution kernel family.
//!
//! Picks the right `Routine` impl from `ConvAlgorithm`, threads the
//! `Strategy` (`Inferred` / `Forced`) into a matmul `BlueprintStrategy`, and
//! dispatches to the per-operation helper based on `ConvolutionInputs`.

7use cubecl::{Runtime, client::ComputeClient};
8use cubek_matmul::{
9    components::tile::TileMatmulKind,
10    definition::{MatmulElems, TilingBlueprint},
11    routines::{BlueprintStrategy, Routine as MatmulRoutine, TilingArgs},
12};
13
14use crate::components::ConvolutionOperation;
15use crate::definition::ConvBlueprint;
16
17fn blueprint_operation(blueprint: &ConvBlueprint) -> ConvolutionOperation {
18    match blueprint {
19        ConvBlueprint::Forward(_) => ConvolutionOperation::Forward,
20        ConvBlueprint::BackwardData(_) => ConvolutionOperation::BackwardData,
21        ConvBlueprint::BackwardWeight(_) => ConvolutionOperation::BackwardWeight,
22    }
23}
24
25use crate::{
26    components::{ConvSetupError, global::args::RuntimeArgs},
27    kernels::{backward_data, backward_weight, forward},
28    launch::{
29        ConvAlgorithm, ConvolutionArgs, ConvolutionInputs, Strategy, strategy::AcceleratedTileKind,
30    },
31    routines::{
32        Routine,
33        simple::{
34            SimpleAsyncCyclicConv, SimpleAsyncStridedConv, SimpleAsyncTmaConv,
35            SimpleSyncCyclicConv, SimpleSyncStridedConv, SimpleSyncTilewiseConv,
36        },
37        specialized::{
38            SpecializedAsyncCyclicConv, SpecializedAsyncStridedConv, SpecializedTmaConv,
39        },
40    },
41};
42
43/// Map `AcceleratedTileKind` → matmul's `TileMatmulKind`.
44pub(crate) fn tile_kind_to_dispatch(kind: AcceleratedTileKind) -> TileMatmulKind {
45    match kind {
46        AcceleratedTileKind::Cmma => TileMatmulKind::Cmma,
47        AcceleratedTileKind::Mma => TileMatmulKind::Mma,
48    }
49}
50
51/// The single public convolution entry point.
52///
53/// Routes the `inputs` (whose discriminant is the operation) and `strategy`
54/// (algorithm + tile-matmul kind, optionally a forced blueprint) into the right
55/// generic `Routine` and per-operation launch helper.
56#[allow(clippy::result_large_err)]
57pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
58    strategy: &Strategy,
59    client: &ComputeClient<R>,
60    inputs: ConvolutionInputs<R>,
61    args: ConvolutionArgs<N_SPATIAL>,
62    dtypes: MatmulElems,
63) -> Result<(), ConvSetupError> {
64    let (algorithm, tile_kind, forced_matmul) = match strategy {
65        Strategy::Inferred {
66            algorithm,
67            tile_kind,
68        } => (*algorithm, *tile_kind, None),
69        Strategy::Forced {
70            algorithm,
71            blueprint,
72        } => {
73            debug_assert_eq!(
74                inputs.operation(),
75                blueprint_operation(blueprint),
76                "Strategy::Forced blueprint variant does not match the inputs operation",
77            );
78            let matmul = blueprint.matmul().clone();
79            // For Forced, tile_kind is encoded inside the matmul blueprint, so
80            // the explicit tile_kind here is unused; we pass Cmma as a benign
81            // default (it gets overwritten by the forced blueprint).
82            (*algorithm, AcceleratedTileKind::Cmma, Some(matmul))
83        }
84    };
85
86    // Backward-data does not currently support the TMA reading strategy.
87    if inputs.operation() == ConvolutionOperation::BackwardData
88        && algorithm == ConvAlgorithm::SimpleAsyncTma
89    {
90        return Err(crate::kernels::backward_data::launch::unsupported_tma_error());
91    }
92
93    dispatch_routine::<R, N_SPATIAL>(
94        algorithm,
95        tile_kind,
96        forced_matmul,
97        client,
98        inputs,
99        args,
100        dtypes,
101    )
102}
103
104/// Dispatch on `ConvAlgorithm` to instantiate the right concrete `Routine`
105/// generic, then forward to the per-operation helper.
106#[allow(clippy::result_large_err, clippy::too_many_arguments)]
107fn dispatch_routine<R: Runtime, const N_SPATIAL: usize>(
108    algorithm: ConvAlgorithm,
109    tile_kind: AcceleratedTileKind,
110    forced_matmul: Option<TilingBlueprint>,
111    client: &ComputeClient<R>,
112    inputs: ConvolutionInputs<R>,
113    args: ConvolutionArgs<N_SPATIAL>,
114    dtypes: MatmulElems,
115) -> Result<(), ConvSetupError> {
116    let kind = tile_kind_to_dispatch(tile_kind);
117    match algorithm {
118        ConvAlgorithm::SimpleSyncCyclic => dispatch_inputs::<R, N_SPATIAL, SimpleSyncCyclicConv>(
119            client,
120            inputs,
121            args,
122            kind,
123            forced_matmul,
124            dtypes,
125        ),
126        ConvAlgorithm::SimpleSyncStrided => dispatch_inputs::<R, N_SPATIAL, SimpleSyncStridedConv>(
127            client,
128            inputs,
129            args,
130            kind,
131            forced_matmul,
132            dtypes,
133        ),
134        ConvAlgorithm::SimpleSyncTilewise => {
135            dispatch_inputs::<R, N_SPATIAL, SimpleSyncTilewiseConv>(
136                client,
137                inputs,
138                args,
139                kind,
140                forced_matmul,
141                dtypes,
142            )
143        }
144        ConvAlgorithm::SimpleAsyncCyclic => dispatch_inputs::<R, N_SPATIAL, SimpleAsyncCyclicConv>(
145            client,
146            inputs,
147            args,
148            kind,
149            forced_matmul,
150            dtypes,
151        ),
152        ConvAlgorithm::SimpleAsyncStrided => {
153            dispatch_inputs::<R, N_SPATIAL, SimpleAsyncStridedConv>(
154                client,
155                inputs,
156                args,
157                kind,
158                forced_matmul,
159                dtypes,
160            )
161        }
162        ConvAlgorithm::SimpleAsyncTma => dispatch_inputs::<R, N_SPATIAL, SimpleAsyncTmaConv>(
163            client,
164            inputs,
165            args,
166            kind,
167            forced_matmul,
168            dtypes,
169        ),
170        ConvAlgorithm::SpecializedAsyncCyclic => {
171            dispatch_inputs::<R, N_SPATIAL, SpecializedAsyncCyclicConv>(
172                client,
173                inputs,
174                args,
175                kind,
176                forced_matmul,
177                dtypes,
178            )
179        }
180        ConvAlgorithm::SpecializedAsyncStrided => {
181            dispatch_inputs::<R, N_SPATIAL, SpecializedAsyncStridedConv>(
182                client,
183                inputs,
184                args,
185                kind,
186                forced_matmul,
187                dtypes,
188            )
189        }
190        ConvAlgorithm::SpecializedTma => dispatch_inputs::<R, N_SPATIAL, SpecializedTmaConv>(
191            client,
192            inputs,
193            args,
194            kind,
195            forced_matmul,
196            dtypes,
197        ),
198    }
199}
200
201/// Branch on operation and forward to the per-op launcher.
202///
203/// All three per-op `ConcreteArgs` traits share the same name and the same
204/// blanket impls on `TensorArgs<RuntimeArgs>` / `TensorMapArgs<RuntimeArgs>`,
205/// so the where clause simply requires an impl per operation.
206#[allow(clippy::result_large_err, clippy::too_many_arguments)]
207fn dispatch_inputs<R: Runtime, const N_SPATIAL: usize, Rt: Routine<Blueprint = TilingBlueprint>>(
208    client: &ComputeClient<R>,
209    inputs: ConvolutionInputs<R>,
210    args: ConvolutionArgs<N_SPATIAL>,
211    tile_matmul: TileMatmulKind,
212    forced_matmul: Option<TilingBlueprint>,
213    dtypes: MatmulElems,
214) -> Result<(), ConvSetupError>
215where
216    Rt::Args: forward::args::ConcreteArgs<Rt::MatmulRoutine>
217        + backward_data::args::ConcreteArgs<Rt::MatmulRoutine>
218        + backward_weight::args::ConcreteArgs<Rt::MatmulRoutine>,
219    Rt::Strategy: TilingArgs,
220{
221    let blueprint_strategy = build_blueprint_strategy::<Rt>(tile_matmul, forced_matmul);
222
223    match inputs {
224        ConvolutionInputs::Forward {
225            input,
226            weight,
227            bias,
228            out,
229        } => forward::launch::launch_internal::<R, N_SPATIAL, Rt>(
230            client,
231            input,
232            weight,
233            bias,
234            out,
235            args,
236            &blueprint_strategy,
237            dtypes,
238        ),
239        ConvolutionInputs::BackwardData {
240            out_grad,
241            weights,
242            in_grad,
243        } => backward_data::launch::launch_internal::<R, N_SPATIAL, Rt>(
244            client,
245            out_grad,
246            weights,
247            in_grad,
248            args,
249            &blueprint_strategy,
250            dtypes,
251        ),
252        ConvolutionInputs::BackwardWeight {
253            input,
254            out_grad,
255            weight_grad,
256        } => backward_weight::launch::launch_internal::<R, N_SPATIAL, Rt>(
257            client,
258            input,
259            out_grad,
260            weight_grad,
261            args,
262            &blueprint_strategy,
263            dtypes,
264        ),
265    }
266}
267
268/// Build a matmul `BlueprintStrategy` from either a forced `TilingBlueprint`
269/// (extracted from `ConvBlueprint`) or an `Inferred` strategy stamped with the
270/// requested tile-matmul kind.
271fn build_blueprint_strategy<Rt: Routine<Blueprint = TilingBlueprint>>(
272    tile_matmul: TileMatmulKind,
273    forced_matmul: Option<TilingBlueprint>,
274) -> BlueprintStrategy<RuntimeArgs, Rt::MatmulRoutine>
275where
276    Rt::Strategy: TilingArgs,
277{
278    match forced_matmul {
279        Some(matmul) => BlueprintStrategy::Forced(matmul),
280        None => {
281            let mut s = <Rt::MatmulRoutine as MatmulRoutine<RuntimeArgs>>::Strategy::default();
282            s.set_tile_matmul(tile_matmul);
283            BlueprintStrategy::Inferred(s)
284        }
285    }
286}