Skip to main content

cubek_convolution/kernels/backward_weight/launch.rs

1use crate::components::{ConvolutionProblem, Dimensionality};
2use crate::{
3    AcceleratedTileKind, ReadingStrategy, algorithm::Algorithm,
4    components::global::args::RuntimeArgs,
5};
6use crate::{
7    ConvolutionArgs, Strategy, backward_weight::args::ConcreteArgs,
8    components::ConvolutionOperation, kernels::algorithm::simple::*,
9};
10use crate::{
11    components::ConvSetupError, kernels::backward_weight::selector::launch_kernel_concrete,
12};
13use cubecl::{Runtime, client::ComputeClient, prelude::*};
14use cubek_matmul::components::tile::{cmma::CmmaMatmul, mma::MmaMatmul};
15use cubek_matmul::{
16    definition::{AvailableVectorSizes, MatmulElems},
17    routines::BlueprintStrategy,
18};
19use cubek_std::{
20    InputBinding,
21    {MatrixLayout, tile::Strided},
22};
23use derive_new::new;
24
/// Dispatches on an [`AcceleratedTileKind`] value and runs `$body` with `$alias` declared as a
/// local type alias for the matching tile matmul:
///
/// * `AcceleratedTileKind::Mma`  -> `MmaMatmul<Strided, Strided, Option<Strided>>`
/// * `AcceleratedTileKind::Cmma` -> `CmmaMatmul<Option<Strided>>`
///
/// `$body` is a zero-argument callable (typically a closure). It is evaluated and invoked inside
/// the selected arm, after the alias is declared, so its code may name `$alias` as a type. The
/// macro expands to whatever the callable returns.
macro_rules! with_tile_kind {
    ($tile_kind:expr, $alias:ident, $body:expr) => {
        match $tile_kind {
            AcceleratedTileKind::Mma => {
                type $alias = MmaMatmul<Strided, Strided, Option<Strided>>;
                let run = $body;
                run()
            }
            AcceleratedTileKind::Cmma => {
                type $alias = CmmaMatmul<Option<Strided>>;
                let run = $body;
                run()
            }
        }
    };
}
39
40/// Perform an n-dimensional convolution using the implicit GEMM (im2col) algorithm, using cubecl
41/// tiling matmul components, using the specified algorithm.
42///
43/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
44/// * `weight` - The weights (filter) applied to each kernel, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
45/// * `out` - The output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
46/// * `bias` - The bias added to each out channel
47/// * `options` - The options to use for the convolution
48#[allow(clippy::result_large_err, clippy::too_many_arguments)]
49pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
50    strategy: &Strategy,
51    client: &ComputeClient<R>,
52    input: InputBinding<R>,
53    out_grad: InputBinding<R>,
54    weight_grad: TensorBinding<R>,
55    args: ConvolutionArgs<N_SPATIAL>,
56    dtypes: MatmulElems,
57) -> Result<(), ConvSetupError> {
58    let backprop = BackwardsWeight::new(client, input, out_grad, weight_grad, args, dtypes);
59
60    match strategy {
61        Strategy::Simple {
62            read_strategy,
63            tile_kind,
64        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
65            ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
66            ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
67            ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
68            ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
69            ReadingStrategy::AsyncStrided =>
70                backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
71            ReadingStrategy::Tma => backprop.launch::<SimpleAsyncTmaConv<Accelerated>>(),
72        }),
73    }
74}
75
76#[derive(new)]
77struct BackwardsWeight<'a, R: Runtime, const N_SPATIAL: usize> {
78    client: &'a ComputeClient<R>,
79    input: InputBinding<R>,
80    out_grad: InputBinding<R>,
81    weight_grad: TensorBinding<R>,
82    args: ConvolutionArgs<N_SPATIAL>,
83    dtypes: MatmulElems,
84}
85
86impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsWeight<'a, R, N_SPATIAL> {
87    fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
88    where
89        Alg::Args: ConcreteArgs<Alg::Routine>,
90    {
91        let ConvolutionArgs {
92            stride,
93            padding,
94            dilation,
95        } = self.args;
96
97        let dimensionality = match N_SPATIAL {
98            1 => Dimensionality::Dim1,
99            2 => Dimensionality::Dim2,
100            3 => Dimensionality::Dim3,
101            other => unimplemented!("Unsupported dimensionality {other}"),
102        };
103
104        launch_with_algorithm::<R, Alg>(
105            self.client,
106            self.input,
107            self.out_grad,
108            self.weight_grad,
109            (&stride, &padding, &dilation),
110            dimensionality,
111            self.dtypes,
112        )
113    }
114}
115
116#[allow(clippy::too_many_arguments)]
117fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
118    client: &ComputeClient<R>,
119    input: InputBinding<R>,
120    out_grad: InputBinding<R>,
121    weight_grad: TensorBinding<R>,
122    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
123    dimensionality: Dimensionality,
124    dtypes: MatmulElems,
125) -> Result<(), ConvSetupError>
126where
127    Alg::Args: ConcreteArgs<Alg::Routine>,
128{
129    let rank = input.data().shape.len();
130    let dim_c = rank - 1;
131
132    let n = input.shape()[0];
133    let c = input.shape()[dim_c];
134
135    let out_c = out_grad.shape()[dim_c];
136
137    let in_shape = &input.shape()[1..dim_c];
138    let kernel_shape = &weight_grad.shape[1..dim_c];
139    let out_shape = &out_grad.shape()[1..dim_c];
140
141    let op = ConvolutionOperation::BackwardWeight;
142
143    let input_data = Alg::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
144    let out_grad_data =
145        Alg::correct_layout(client, out_grad.clone().into_data(), dtypes.rhs_global, op)?;
146
147    let mut input = input.clone();
148    let mut out_grad = out_grad.clone();
149
150    *input.data_mut() = input_data;
151    *out_grad.data_mut() = out_grad_data;
152
153    let address_type = input
154        .required_address_type()
155        .max(out_grad.required_address_type())
156        .max(weight_grad.required_address_type(dtypes.acc_global.size()));
157
158    let problem = ConvolutionProblem {
159        m: out_c,
160        n: c * kernel_shape.iter().product::<usize>(),
161        k: n * out_shape.iter().product::<usize>(),
162        lhs_strides: input.data().strides.clone(),
163        rhs_strides: out_grad.data().strides.clone(),
164        lhs_layout: MatrixLayout::ColMajor,
165        rhs_layout: MatrixLayout::RowMajor,
166        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
167        stride: stride.iter().map(|it| *it as u32).collect(),
168        padding: padding.iter().map(|it| *it as i32).collect(),
169        dilation: dilation.iter().map(|it| *it as u32).collect(),
170
171        batches: n,
172        in_shape: in_shape.into(),
173        out_shape: out_shape.into(),
174        channels: c,
175        out_channels: out_c,
176
177        padded_channels: c,
178        operation: op,
179
180        dimensionality,
181        global_dtypes: dtypes.as_global_elems(),
182        address_type,
183    };
184
185    launch_kernel::<R, Alg>(
186        client,
187        input,
188        out_grad,
189        weight_grad,
190        problem,
191        &BlueprintStrategy::Inferred(Default::default()),
192        dtypes,
193    )
194}
195
196#[allow(clippy::result_large_err, clippy::too_many_arguments)]
197pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
198    client: &ComputeClient<R>,
199    input: InputBinding<R>,
200    out_grad: InputBinding<R>,
201    weight_grad: TensorBinding<R>,
202    problem: ConvolutionProblem,
203    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
204    dtypes: MatmulElems,
205) -> Result<(), ConvSetupError>
206where
207    Alg::Args: ConcreteArgs<Alg::Routine>,
208{
209    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
210    // So for the sake of selecting a vector size, the shape/strides are always row-major.
211    let vector_sizes = AvailableVectorSizes::from_type_sizes(
212        client,
213        input.data_elem_size(),
214        out_grad.data_elem_size(),
215        dtypes.acc_global.size(),
216    )
217    .filter_lhs_with_tensor(
218        &out_grad.data().strides,
219        &out_grad.data().shape,
220        MatrixLayout::RowMajor,
221    )
222    .filter_rhs_with_tensor(
223        &input.data().strides,
224        &input.data().shape,
225        MatrixLayout::RowMajor,
226    )
227    .filter_out_with_tensor(&weight_grad.strides, &weight_grad.shape);
228
229    let vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;
230
231    launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
232        client,
233        input,
234        out_grad,
235        weight_grad,
236        problem,
237        vector_sizes,
238        blueprint_strategy,
239        &dtypes,
240    )
241}