
cubek_convolution/kernels/backward_data/
launch.rs

use crate::{
    AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,
    backward_data::args::ConcreteArgs,
    components::{ConvolutionOperation, global::args::RuntimeArgs},
    kernels::algorithm::simple::*,
};
use crate::{components::ConvSetupError, kernels::backward_data::selector::launch_kernel_concrete};
use crate::{
    components::{ConvolutionProblem, Dimensionality},
    kernels::algorithm::Algorithm,
};
use cubecl::{Runtime, client::ComputeClient, prelude::*};
use cubek_matmul::{
    components::tile_matmul::DispatchTileMatmul,
    definition::{AvailableVectorSizes, MatmulElems, MatmulSetupError},
    routines::{BlueprintStrategy, Routine, TilingArgs},
};
use cubek_std::{InputBinding, MatrixLayout};
use derive_new::new;

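/// Map the user-facing accelerated tile kind onto the corresponding tile matmul dispatch.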
fn tile_kind_to_dispatch(kind: &AcceleratedTileKind) -> DispatchTileMatmul {
    match kind {
        AcceleratedTileKind::Cmma => DispatchTileMatmul::Cmma,
        AcceleratedTileKind::Mma => DispatchTileMatmul::Mma,
    }
}

/// Compute the data gradient of an n-dimensional convolution using the implicit GEMM
/// (im2col) algorithm, built on the cubecl tiling matmul components, with the specified
/// strategy.
///
/// * `strategy` - The reading strategy and accelerated tile kind to launch with
/// * `client` - The compute client to launch on
/// * `out_grad` - The gradient of the output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
/// * `weights` - The weights (filter) of the forward convolution, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
/// * `in_grad` - The gradient of the input feature map to be computed, layout should be [batches, depth, height, width, in_channels]
/// * `args` - The stride, padding and dilation of the forward convolution
/// * `dtypes` - The element types to use for the matmul
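///
/// # Example
///
/// A minimal sketch of a call site for a 2D backward-data pass (`N_SPATIAL = 2`). The
/// runtime `R`, the `client`, the tensor bindings, and `dtypes` are assumed to already
/// exist, and `ConvolutionArgs` is assumed to be constructible as a plain struct literal:
///
/// ```ignore
/// let args = ConvolutionArgs {
///     stride: [1, 1],
///     padding: [0, 0],
///     dilation: [1, 1],
/// };
/// launch_ref::<R, 2>(
///     &Strategy::Simple {
///         read_strategy: ReadingStrategy::Cyclic,
///         tile_kind: AcceleratedTileKind::Cmma,
///     },
///     &client,
///     out_grad, // [batches, out_h, out_w, out_channels]
///     weights,  // [out_channels, kernel_h, kernel_w, in_channels]
///     in_grad,  // [batches, height, width, in_channels]
///     args,
///     dtypes,
/// )?;
/// ```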
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    out_grad: InputBinding<R>,
    weights: InputBinding<R>,
    in_grad: TensorBinding<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    let backprop = BackwardsData::new(client, out_grad, weights, in_grad, args, dtypes);

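    // Each reading strategy maps to a distinct algorithm type, so the launch is
    // monomorphized per strategy here.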
    match strategy {
        Strategy::Simple {
            read_strategy,
            tile_kind,
        } => {
            let kind = tile_kind_to_dispatch(tile_kind);
            match read_strategy {
                ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv>(kind),
                ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv>(kind),
                ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv>(kind),
                ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv>(kind),
                ReadingStrategy::AsyncStrided => backprop.launch::<SimpleAsyncStridedConv>(kind),
                ReadingStrategy::Tma => {
                    Err(ConvSetupError::Matmul(MatmulSetupError::InvalidConfig(
                        Box::new("Data backprop doesn't yet work with the current TMA tiling strategy"),
                    )))
                }
            }
        }
    }
}

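/// Bundles the launch operands so each reading strategy can go through the same
/// generic `launch` entry point.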
#[derive(new)]
struct BackwardsData<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    out_grad: InputBinding<R>,
    weights: InputBinding<R>,
    in_grad: TensorBinding<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
}

impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsData<'a, R, N_SPATIAL> {
    fn launch<Alg: Algorithm>(self, tile_matmul: DispatchTileMatmul) -> Result<(), ConvSetupError>
    where
        Alg::Args: ConcreteArgs<Alg::Routine>,
        <Alg::Routine as Routine<RuntimeArgs>>::Strategy: TilingArgs,
    {
        let ConvolutionArgs {
            stride,
            padding,
            dilation,
        } = self.args;

        let dimensionality = match N_SPATIAL {
            1 => Dimensionality::Dim1,
            2 => Dimensionality::Dim2,
            3 => Dimensionality::Dim3,
            other => unimplemented!("Unsupported dimensionality {other}"),
        };

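        // Start from the routine's default tiling arguments and only pin the
        // requested tile matmul dispatch.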
        let mut args = <Alg::Routine as Routine<RuntimeArgs>>::Strategy::default();
        args.set_tile_matmul(tile_matmul);

        launch_with_algorithm::<R, Alg>(
            self.client,
            self.out_grad,
            self.weights,
            self.in_grad,
            (&stride, &padding, &dilation),
            dimensionality,
            &BlueprintStrategy::Inferred(args),
            self.dtypes,
        )
    }
}

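/// Build the implicit GEMM [`ConvolutionProblem`] for the data-gradient pass from the
/// operand shapes, then hand off to [`launch_kernel`].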
#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    out_grad: InputBinding<R>,
    weights: InputBinding<R>,
    in_grad: TensorBinding<R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
{
    let rank = in_grad.shape.len();
    let dim_c = rank - 1;

    let n = in_grad.shape[0];
    let c = in_grad.shape[dim_c];

    let out_c = out_grad.shape()[dim_c];

    let in_shape = &in_grad.shape[1..dim_c];
    let kernel_shape = &weights.shape()[1..dim_c];
    let out_shape = &out_grad.shape()[1..dim_c];

    let op = ConvolutionOperation::BackwardData;

    // Apply any layout correction the algorithm requires, then rebind the
    // (possibly repacked) data onto the operands so the problem below sees the
    // corrected strides.
    let mut out_grad = out_grad;
    let mut weights = weights;

    let out_grad_data =
        Alg::correct_layout(client, out_grad.clone().into_data(), dtypes.lhs_global, op)?;
    let weights_data =
        Alg::correct_layout(client, weights.clone().into_data(), dtypes.rhs_global, op)?;

    *out_grad.data_mut() = out_grad_data;
    *weights.data_mut() = weights_data;

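    // The kernel must use the widest address type required by any operand.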
    let address_type = out_grad
        .required_address_type()
        .max(weights.required_address_type())
        .max(in_grad.required_address_type(dtypes.acc_global.size()));

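    // Implicit GEMM mapping of the data-gradient pass: one row per input-gradient
    // spatial position across all batches (m), one column per input channel (n),
    // reducing over the output channels times the kernel volume (k).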
    let problem = ConvolutionProblem {
        m: n * in_shape.iter().product::<usize>(),
        n: c,
        k: out_c * kernel_shape.iter().product::<usize>(),

        lhs_strides: out_grad.data().strides.clone(),
        rhs_strides: weights.data().strides.clone(),
        lhs_layout: MatrixLayout::RowMajor,
        rhs_layout: MatrixLayout::RowMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.into(),
        out_shape: out_shape.into(),
        channels: c,
        out_channels: out_c,

        padded_channels: out_c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
        address_type,
    };

    launch_kernel::<R, Alg>(
        client,
        out_grad,
        weights,
        in_grad,
        problem,
        blueprint_strategy,
        dtypes,
    )
}

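/// Launch the backward-data kernel for an already-constructed [`ConvolutionProblem`],
/// picking the vector sizes supported by the operand layouts before dispatching.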
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    out_grad: InputBinding<R>,
    weights: InputBinding<R>,
    in_grad: TensorBinding<R>,
    problem: ConvolutionProblem,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
{
    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
    // So for the sake of selecting a vector size, the shape/strides are always row-major.
    let vector_sizes = AvailableVectorSizes::from_type_sizes(
        client,
        out_grad.data_elem_size(),
        weights.data_elem_size(),
        dtypes.acc_global.size(),
    )
    .filter_lhs_with_tensor(
        &out_grad.data().strides,
        &out_grad.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        &weights.data().strides,
        &weights.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(&in_grad.strides, &in_grad.shape);

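    // Let the algorithm veto vector sizes it cannot handle, then take the widest
    // size that remains.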
    let vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;

    launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
        client,
        out_grad,
        weights,
        in_grad,
        problem,
        vector_sizes,
        blueprint_strategy,
        &dtypes,
    )
}