cubek_convolution/kernels/backward_data/launch.rs

use crate::{
    AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,
    backward_data::args::ConcreteArgs,
    components::{ConvolutionOperation, global::args::RuntimeArgs},
    kernels::algorithm::simple::*,
};
use crate::{components::ConvSetupError, kernels::backward_data::selector::launch_kernel_concrete};
use crate::{
    components::{ConvolutionProblem, Dimensionality},
    kernels::algorithm::Algorithm,
};
use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::{CubeOption, tensor::TensorHandle},
};
use cubek_matmul::{
    components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
    definition::{AvailableLineSizes, MatmulElems, MatmulSetupError, MatrixLayout},
    launch::{MatmulInputHandle, MatmulInputHandleRef},
    routines::BlueprintStrategy,
};
use derive_new::new;

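/// Dispatch on the accelerated tile kind: aliases `$T` to the CMMA- or
/// MMA-backed tile matmul type so `$launch` can be monomorphized for the
/// selected backend.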
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                type $T = CmmaMatmul<CubeOption<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                type $T = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($launch)()
            }
        }
    };
}

/// Launch the data-gradient pass from owned handles; a thin wrapper over
/// [`launch_ref`].
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    out_grad: MatmulInputHandle<R>,
    weights: MatmulInputHandle<R>,
    in_grad: TensorHandle<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    launch_ref(
        strategy,
        client,
        &out_grad.as_ref(),
        &weights.as_ref(),
        &in_grad.as_ref(),
        args,
        dtypes,
    )
}

/// Perform the data-gradient (backward) pass of an n-dimensional convolution
/// using the implicit GEMM (im2col) algorithm, built on cubecl tiling matmul
/// components, with the specified strategy.
///
/// * `out_grad` - The gradient w.r.t. the output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
/// * `weights` - The weights (filter) applied to each kernel position, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
/// * `in_grad` - The computed gradient w.r.t. the input feature map, layout should be [batches, depth, height, width, in_channels]
/// * `args` - The stride, padding, and dilation to use for the convolution
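///
/// # Example
///
/// A minimal call-site sketch (marked `ignore` because handle construction is
/// runtime-specific); it assumes `client`, `out_grad`, `weights`, `in_grad`,
/// `args`, and `dtypes` were already built for a concrete [`Runtime`] `R`:
///
/// ```ignore
/// let strategy = Strategy::Simple {
///     read_strategy: ReadingStrategy::Cyclic,
///     tile_kind: AcceleratedTileKind::Cmma,
/// };
/// launch_ref::<R, 2>(
///     &strategy,
///     &client,
///     &out_grad.as_ref(),
///     &weights.as_ref(),
///     &in_grad.as_ref(),
///     args,
///     dtypes,
/// )?;
/// ```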
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    let backprop = BackwardsData::new(client, out_grad, weights, in_grad, args, dtypes);

    match strategy {
        Strategy::Simple {
            read_strategy,
            tile_kind,
        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
            ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
            ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
            ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
            ReadingStrategy::AsyncStrided =>
                backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
            ReadingStrategy::Tma => Err(ConvSetupError::Matmul(MatmulSetupError::InvalidConfig(
                Box::new("Data backprop doesn't yet work with the current TMA tiling strategy")
            ))),
        }),
    }
}

/// Parameters for a data-gradient launch, bundled so the strategy match in
/// [`launch_ref`] can instantiate `launch` once per algorithm type.
#[derive(new)]
struct BackwardsData<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    out_grad: &'a MatmulInputHandleRef<'a, R>,
    weights: &'a MatmulInputHandleRef<'a, R>,
    in_grad: &'a TensorHandleRef<'a, R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
}

impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsData<'a, R, N_SPATIAL> {
    fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
    where
        Alg::Args: ConcreteArgs<Alg::Routine>,
    {
        let ConvolutionArgs {
            stride,
            padding,
            dilation,
        } = self.args;

        let dimensionality = match N_SPATIAL {
            1 => Dimensionality::Dim1,
            2 => Dimensionality::Dim2,
            3 => Dimensionality::Dim3,
            other => unimplemented!("Unsupported dimensionality {other}"),
        };

        launch_with_algorithm::<R, Alg>(
            self.client,
            self.out_grad,
            self.weights,
            self.in_grad,
            (&stride, &padding, &dilation),
            dimensionality,
            &BlueprintStrategy::Inferred(Default::default()),
            self.dtypes,
        )
    }
}

/// Prepare the tensor handles, build the [`ConvolutionProblem`] describing the
/// data-gradient GEMM, and forward to [`launch_kernel`].
#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
{
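    // Tensors are channels-last ([batch, spatial dims.., channels]), so the
    // spatial extents sit between index 1 and the final (channel) dimension.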
    let rank = in_grad.shape.len();
    let dim_c = rank - 1;

    let n = in_grad.shape[0];
    let c = in_grad.shape[dim_c];

    let out_c = out_grad.shape()[dim_c];

    let in_shape = &in_grad.shape[1..dim_c];
    let kernel_shape = &weights.shape()[1..dim_c];
    let out_shape = &out_grad.shape()[1..dim_c];

    let op = ConvolutionOperation::BackwardData;

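    // Let the algorithm prepare the raw tensors as needed (e.g. dtype
    // conversion), then point the handles at the converted buffers below.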
    let out_grad_data = Alg::into_tensor_handle(client, out_grad.data(), dtypes.lhs_global, op)?;
    let weights_data = Alg::into_tensor_handle(client, weights.data(), dtypes.rhs_global, op)?;

    let mut out_grad = *out_grad;
    let mut weights = *weights;

    *out_grad.data_mut() = out_grad_data.as_ref();
    *weights.data_mut() = weights_data.as_ref();

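    // Implicit GEMM view of the data-gradient pass: each in_grad position is a
    // GEMM row (m = batches * input spatial size), the reduction axis spans
    // out_channels * kernel positions (k), and the columns are the input
    // channels (n).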
    let problem = ConvolutionProblem {
        m: n * in_shape.iter().product::<usize>(),
        n: c,
        k: out_c * kernel_shape.iter().product::<usize>(),

        lhs_strides: out_grad.data().strides.to_vec(),
        rhs_strides: weights.data().strides.to_vec(),
        lhs_layout: MatrixLayout::RowMajor,
        rhs_layout: MatrixLayout::RowMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.to_vec(),
        out_shape: out_shape.to_vec(),
        channels: c,
        out_channels: out_c,

        padded_channels: out_c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
    };

    launch_kernel::<R, Alg>(
        client,
        &out_grad,
        &weights,
        in_grad,
        problem,
        blueprint_strategy,
        dtypes,
    )
}

/// Select compatible line sizes for each tensor and dispatch to the concrete
/// kernel launch for the chosen algorithm.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    out_grad: &MatmulInputHandleRef<'_, R>,
    weights: &MatmulInputHandleRef<'_, R>,
    in_grad: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
{
    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
    // So for the sake of selecting a line size, the shape/strides are always row-major.
    let line_sizes = AvailableLineSizes::from_type_sizes(
        client,
        out_grad.data().elem_size,
        weights.data().elem_size,
        in_grad.elem_size,
    )
    .filter_lhs_with_tensor(
        out_grad.data().strides,
        out_grad.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        weights.data().strides,
        weights.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(in_grad.strides, in_grad.shape);

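    // After any algorithm-specific filtering, commit to the largest line size
    // that survived every per-tensor filter above.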
    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;

    launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
        client,
        out_grad,
        weights,
        in_grad,
        problem,
        line_sizes,
        blueprint_strategy,
        &dtypes,
    )
}