cubek_convolution/kernels/backward_data/launch.rs

use crate::{
    AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,
    backward_data::args::ConcreteArgs,
    components::{ConvGemmConfig as _, ConvolutionOperation},
    kernels::forward::simple::*,
};
use crate::{components::ConvSetupError, kernels::backward_data::selector::launch_kernel_concrete};
use crate::{
    components::{ConvolutionProblem, Dimensionality},
    kernels::forward::algorithm::Algorithm,
};
use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::{CubeOption, tensor::TensorHandle},
};
use cubek_matmul::{
    components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
    definition::{AvailableLineSizes, MatmulElems, MatmulSetupError, MatrixLayout},
    launch::{MatmulInputHandle, MatmulInputHandleRef},
};
use derive_new::new;

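/// Binds the type alias `$T` to the accelerated tile matmul implementation selected by
/// `$kind` (CMMA or MMA), then evaluates `$launch` with that alias in scope.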
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                type $T = CmmaMatmul<CubeOption<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                type $T = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($launch)()
            }
        }
    };
}

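/// Launch the data-gradient (backward data) convolution kernel, taking owned input
/// handles and forwarding to [`launch_ref`].
///
/// A minimal usage sketch (illustrative only; how the client, handles, `args` and
/// `dtypes` are built depends on the caller's runtime setup and is assumed here):
///
/// ```ignore
/// // `client`, `out_grad`, `weights`, `in_grad`, `args` and `dtypes` are assumed to
/// // already exist; `MyRuntime` stands in for a concrete cubecl runtime.
/// let strategy = Strategy::Simple {
///     read_strategy: ReadingStrategy::Cyclic,
///     tile_kind: AcceleratedTileKind::Cmma,
/// };
/// launch::<MyRuntime, 2>(&strategy, &client, out_grad, weights, in_grad, args, dtypes)?;
/// ```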
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    out_grad: MatmulInputHandle<R>,
    weights: MatmulInputHandle<R>,
    in_grad: TensorHandle<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    launch_ref(
        strategy,
        client,
        &out_grad.as_ref(),
        &weights.as_ref(),
        &in_grad.as_ref(),
        args,
        dtypes,
    )
}

/// Compute the data gradient of an n-dimensional convolution using the implicit GEMM
/// algorithm, built from cubecl tiling matmul components and the specified strategy.
///
/// * `out_grad` - The output-gradient feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
/// * `weights` - The weights (filter) of the forward convolution, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
/// * `in_grad` - The input-gradient feature map written by this kernel, layout should be [batches, depth, height, width, in_channels]
/// * `args` - The stride, padding and dilation of the forward convolution
/// * `dtypes` - The element types used for the underlying matmul
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    let backprop = BackwardsData::new(client, out_grad, weights, in_grad, args, dtypes);

    match strategy {
        Strategy::Simple {
            read_strategy,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
            ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
            ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
            ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
            ReadingStrategy::AsyncStrided =>
                backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tma => Err(ConvSetupError::Matmul(MatmulSetupError::InvalidConfig(
                Box::new("Data backprop doesn't yet work with current TMA tiling strategy")
            ))),
        }),
    }
}

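/// Arguments for a data-gradient convolution launch, bundled so the reading-strategy
/// match in [`launch_ref`] can pick an [`Algorithm`] without repeating every parameter.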
#[derive(new)]
struct BackwardsData<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    out_grad: &'a MatmulInputHandleRef<'a, R>,
    weights: &'a MatmulInputHandleRef<'a, R>,
    in_grad: &'a TensorHandleRef<'a, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
}

impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsData<'a, R, N_SPATIAL> {
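    /// Resolve the spatial dimensionality from `N_SPATIAL` and dispatch to
    /// `launch_with_algorithm` with the chosen algorithm.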
    fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
    where
        Alg::Args: ConcreteArgs,
    {
        let ConvolutionArgs {
            stride,
            padding,
            dilation,
        } = self.args;

        let dimensionality = match N_SPATIAL {
            1 => Dimensionality::Dim1,
            2 => Dimensionality::Dim2,
            3 => Dimensionality::Dim3,
            other => unimplemented!("Unsupported dimensionality {other}"),
        };

        launch_with_algorithm::<R, Alg>(
            self.client,
            self.out_grad,
            self.weights,
            self.in_grad,
            (&stride, &padding, &dilation),
            dimensionality,
            self.dtypes,
        )
    }
}

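/// Convert the operand handles into the form expected by `Alg` and describe the data
/// backprop as an implicit GEMM [`ConvolutionProblem`]: `m` is batches * input spatial
/// positions, `n` is the input channel count, and `k` is out_channels * kernel volume.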
#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs,
{
    let rank = in_grad.shape.len();
    let dim_c = rank - 1;

    let n = in_grad.shape[0];
    let c = in_grad.shape[dim_c];

    let out_c = out_grad.shape()[dim_c];

    let in_shape = &in_grad.shape[1..dim_c];
    let kernel_shape = &weights.shape()[1..dim_c];
    let out_shape = &out_grad.shape()[1..dim_c];

    let op = ConvolutionOperation::BackwardData;

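    // Let the algorithm convert the operands into its preferred handle representation,
    // then point the borrowed handles at the converted data.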
    let out_grad_data = Alg::into_tensor_handle(client, out_grad.data(), dtypes.lhs_global, op)?;
    let weights_data = Alg::into_tensor_handle(client, weights.data(), dtypes.rhs_global, op)?;

    let mut out_grad = *out_grad;
    let mut weights = *weights;

    *out_grad.data_mut() = out_grad_data.as_ref();
    *weights.data_mut() = weights_data.as_ref();

    let problem = ConvolutionProblem {
        m: n * in_shape.iter().product::<usize>(),
        n: c,
        k: out_c * kernel_shape.iter().product::<usize>(),

        lhs_strides: out_grad.data().strides.to_vec(),
        rhs_strides: weights.data().strides.to_vec(),
        lhs_layout: MatrixLayout::RowMajor,
        rhs_layout: MatrixLayout::RowMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.to_vec(),
        out_shape: out_shape.to_vec(),
        channels: c,
        out_channels: out_c,

        padded_channels: out_c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
    };

    launch_kernel::<R, Alg>(client, &out_grad, &weights, in_grad, problem, dtypes)
}

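/// Pick line sizes and a matmul selection for `Alg`, expand the convolution config, and
/// dispatch the concrete backward-data kernel.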
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    mut dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs,
{
    let plane_dim = client.properties().hardware.plane_size_max;
    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
    // So for the sake of selecting a line size, the shape/strides are always row-major.
    let line_sizes = AvailableLineSizes::from_type_sizes(
        client,
        out_grad.data().elem_size,
        weights.data().elem_size,
        in_grad.elem_size,
    )
    .filter_lhs_with_tensor(
        out_grad.data().strides,
        out_grad.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        weights.data().strides,
        weights.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(in_grad.strides, in_grad.shape);

    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;

    let selection = Alg::selection(client, &problem, plane_dim, &line_sizes, &mut dtypes)?;
    let problem = Alg::Args::adjust_problem(client, problem, &selection, &dtypes);

    let device_props = client.properties();
    let config = Alg::expand_config(device_props, &problem, &selection, &line_sizes, &dtypes)?;

    let line_sizes = config.line_sizes();

    launch_kernel_concrete::<R, Alg>(
        client, out_grad, weights, in_grad, problem, line_sizes, selection, &dtypes,
    )
}