cubek_convolution/kernels/backward_weight/launch.rs

use crate::{AcceleratedTileKind, ReadingStrategy};
use crate::{
    ConvolutionArgs, Strategy,
    backward_weight::args::ConcreteArgs,
    components::{ConvGemmConfig as _, ConvolutionOperation},
    kernels::forward::simple::*,
};
use crate::{
    components::ConvSetupError, kernels::backward_weight::selector::launch_kernel_concrete,
};
use crate::{
    components::{ConvolutionProblem, Dimensionality},
    kernels::forward::algorithm::Algorithm,
};
use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::{CubeOption, tensor::TensorHandle},
};
use cubek_matmul::definition::{AvailableLineSizes, MatmulElems, MatrixLayout};
use cubek_matmul::launch::{MatmulInputHandle, MatmulInputHandleRef};
use cubek_matmul::{
    components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
    definition,
};
use derive_new::new;

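/// Expands `$launch` with the type alias `$T` bound to the accelerated tile
/// matmul implementation selected by `$kind` (CMMA or MMA), so each tile kind
/// gets its own monomorphized launch path.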
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                type $T = CmmaMatmul<CubeOption<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                type $T = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($launch)()
            }
        }
    };
}

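/// Compute the weight gradient of an n-dimensional convolution from owned
/// handles. Thin wrapper that borrows the handles and forwards to [`launch_ref`].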
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    input: MatmulInputHandle<R>,
    out_grad: MatmulInputHandle<R>,
    weight_grad: TensorHandle<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    launch_ref(
        strategy,
        client,
        &input.as_ref(),
        &out_grad.as_ref(),
        &weight_grad.as_ref(),
        args,
        dtypes,
    )
}

/// Compute the weight gradient of an n-dimensional convolution using the
/// implicit GEMM (im2col) algorithm, built from the cubecl tiling matmul
/// components, with the specified strategy.
///
/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
/// * `out_grad` - The gradient of the output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
/// * `weight_grad` - The computed weight (filter) gradient, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
/// * `args` - The stride, padding and dilation to use for the convolution
/// * `dtypes` - The element types of the matmul operands
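///
/// Internally, the launch reduces to an implicit GEMM over the im2col
/// expansion of the input. As constructed in `launch_with_algorithm` below,
/// the problem has `m = out_channels`, `n = in_channels * prod(kernel_size)`
/// and `k = batches * prod(out_shape)`, i.e. it computes
/// `weight_grad = out_grad^T * im2col(input)`.
///
/// # Example
///
/// A minimal sketch of a 2D weight-gradient launch. The `client`, handles and
/// `dtypes` are assumed to have been created elsewhere, so this is not a
/// compilable doctest:
///
/// ```rust,ignore
/// let strategy = Strategy::Simple {
///     read_strategy: ReadingStrategy::Cyclic,
///     tile_kind: AcceleratedTileKind::Cmma,
/// };
/// launch_ref(
///     &strategy,
///     &client,
///     &input.as_ref(),
///     &out_grad.as_ref(),
///     &weight_grad.as_ref(),
///     // stride, padding and dilation per spatial dimension
///     ConvolutionArgs {
///         stride: [1, 1],
///         padding: [0, 0],
///         dilation: [1, 1],
///     },
///     dtypes,
/// )?;
/// ```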
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    input: &MatmulInputHandleRef<'_, R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weight_grad: &TensorHandleRef<'_, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    let backprop = BackwardsWeight::new(client, input, out_grad, weight_grad, args, dtypes);

    match strategy {
        Strategy::Simple {
            read_strategy,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
            ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
            ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
            ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
            ReadingStrategy::AsyncStrided =>
                backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tma => backprop.launch::<SimpleAsyncTmaConv<Accelerated>>(),
        }),
    }
}

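/// Launch arguments bundled together so the tile-kind and reading-strategy
/// dispatch in [`launch_ref`] stays a one-liner per variant.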
#[derive(new)]
struct BackwardsWeight<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    input: &'a MatmulInputHandleRef<'a, R>,
    out_grad: &'a MatmulInputHandleRef<'a, R>,
    weight_grad: &'a TensorHandleRef<'a, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
}

impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsWeight<'a, R, N_SPATIAL> {
    fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
    where
        Alg::Args: ConcreteArgs,
    {
        let ConvolutionArgs {
            stride,
            padding,
            dilation,
        } = self.args;

        let dimensionality = match N_SPATIAL {
            1 => Dimensionality::Dim1,
            2 => Dimensionality::Dim2,
            3 => Dimensionality::Dim3,
            other => unimplemented!("Unsupported dimensionality {other}"),
        };

        launch_with_algorithm::<R, Alg>(
            self.client,
            self.input,
            self.out_grad,
            self.weight_grad,
            (&stride, &padding, &dilation),
            dimensionality,
            self.dtypes,
        )
    }
}

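/// Build the implicit-GEMM [`ConvolutionProblem`] for the weight gradient and
/// launch it with the selected algorithm.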
#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: &MatmulInputHandleRef<'_, R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weight_grad: &TensorHandleRef<'_, R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs,
{
    let rank = input.data().shape.len();
    let dim_c = rank - 1;

    let n = input.shape()[0];
    let c = input.shape()[dim_c];

    let out_c = out_grad.shape()[dim_c];

    let in_shape = &input.shape()[1..dim_c];
    let kernel_shape = &weight_grad.shape[1..dim_c];
    let out_shape = &out_grad.shape()[1..dim_c];

    let op = ConvolutionOperation::BackwardWeight;

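    // Convert both operands to the element type the algorithm expects for this
    // operation, then rebind the handle refs so the problem below is built from
    // the converted strides.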
    let input_data = Alg::into_tensor_handle(client, input.data(), dtypes.lhs_global, op)?;
    let out_grad_data = Alg::into_tensor_handle(client, out_grad.data(), dtypes.rhs_global, op)?;

    let mut input = *input;
    let mut out_grad = *out_grad;

    *input.data_mut() = input_data.as_ref();
    *out_grad.data_mut() = out_grad_data.as_ref();

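    // GEMM view of the weight gradient: one output row per output channel, one
    // column per (in_channel, kernel position) pair, reduced over every
    // (batch, output position) pair.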
    let problem = ConvolutionProblem {
        m: out_c,
        n: c * kernel_shape.iter().product::<usize>(),
        k: n * out_shape.iter().product::<usize>(),
        lhs_strides: input.data().strides.to_vec(),
        rhs_strides: out_grad.data().strides.to_vec(),
        lhs_layout: definition::MatrixLayout::ColMajor,
        rhs_layout: definition::MatrixLayout::RowMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.to_vec(),
        out_shape: out_shape.to_vec(),
        channels: c,
        out_channels: out_c,

        padded_channels: c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
    };

    launch_kernel::<R, Alg>(client, &input, &out_grad, weight_grad, problem, dtypes)
}

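/// Pick line sizes and a matmul configuration for `problem`, then hand off to
/// the concrete kernel launcher.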
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: &MatmulInputHandleRef<'_, R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weight_grad: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    mut dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs,
{
    let plane_dim = client.properties().hardware.plane_size_max;
    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
    // So for the sake of selecting a line size, the shape/strides are always row-major.
    let line_sizes = AvailableLineSizes::from_type_sizes(
        client,
        input.data().elem_size,
        out_grad.data().elem_size,
        weight_grad.elem_size,
    )
    .filter_lhs_with_tensor(
        out_grad.data().strides,
        out_grad.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        input.data().strides,
        input.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(weight_grad.strides, weight_grad.shape);

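    // Keep only the line sizes the algorithm supports, then take the largest
    // remaining one for each operand.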
    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;

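    // The selection may refine `dtypes` (hence `&mut`); `adjust_problem` then
    // reshapes the problem (e.g. its padded channel count) to fit that selection.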
    let selection = Alg::selection(client, &problem, plane_dim, &line_sizes, &mut dtypes)?;
    let problem = Alg::Args::adjust_problem(client, problem, &selection, &dtypes);

    let config = Alg::expand_config(
        client.properties(),
        &problem,
        &selection,
        &line_sizes,
        &dtypes,
    )?;

    let line_sizes = config.line_sizes();

    launch_kernel_concrete::<R, Alg>(
        client,
        input,
        out_grad,
        weight_grad,
        problem,
        line_sizes,
        selection,
        &dtypes,
    )
}