Skip to main content

cubek_convolution/kernels/backward_weight/
launch.rs

1use crate::components::{ConvolutionProblem, Dimensionality};
2use crate::{
3    AcceleratedTileKind, ReadingStrategy, algorithm::Algorithm,
4    components::global::args::RuntimeArgs,
5};
6use crate::{
7    ConvolutionArgs, Strategy, backward_weight::args::ConcreteArgs,
8    components::ConvolutionOperation, kernels::algorithm::simple::*,
9};
10use crate::{
11    components::ConvSetupError, kernels::backward_weight::selector::launch_kernel_concrete,
12};
13use cubecl::{
14    Runtime,
15    client::ComputeClient,
16    prelude::*,
17    std::{CubeOption, tensor::TensorHandle},
18};
19use cubek_matmul::launch::{MatmulInputHandle, MatmulInputHandleRef};
20use cubek_matmul::{
21    components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
22    definition,
23};
24use cubek_matmul::{
25    definition::{AvailableLineSizes, MatmulElems, MatrixLayout},
26    routines::BlueprintStrategy,
27};
28use derive_new::new;
29
/// Dispatches on an `AcceleratedTileKind` value.
///
/// For the selected variant, `$Alias` is declared as a local type alias naming the matching
/// accelerated tile matmul. `$body` (a closure expression that may refer to `$Alias`) is then
/// invoked with no arguments, and its result is the value of the whole expansion.
macro_rules! with_tile_kind {
    ($tile_kind:expr, $Alias:ident, $body:expr) => {{
        match $tile_kind {
            AcceleratedTileKind::Mma => {
                type $Alias = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($body)()
            }
            AcceleratedTileKind::Cmma => {
                type $Alias = CmmaMatmul<CubeOption<Strided>>;
                ($body)()
            }
        }
    }};
}
44
45#[allow(clippy::result_large_err, clippy::too_many_arguments)]
46pub fn launch<R: Runtime, const N_SPATIAL: usize>(
47    strategy: &Strategy,
48    client: &ComputeClient<R>,
49    input: MatmulInputHandle<R>,
50    out_grad: MatmulInputHandle<R>,
51    weight_grad: TensorHandle<R>,
52    args: ConvolutionArgs<N_SPATIAL>,
53    dtypes: MatmulElems,
54) -> Result<(), ConvSetupError> {
55    launch_ref(
56        strategy,
57        client,
58        &input.as_ref(),
59        &out_grad.as_ref(),
60        &weight_grad.as_ref(),
61        args,
62        dtypes,
63    )
64}
65
66/// Perform an n-dimensional convolution using the implicit GEMM (im2col) algorithm, using cubecl
67/// tiling matmul components, using the specified algorithm.
68///
69/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
70/// * `weight` - The weights (filter) applied to each kernel, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
71/// * `out` - The output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
72/// * `bias` - The bias added to each out channel
73/// * `options` - The options to use for the convolution
74#[allow(clippy::result_large_err, clippy::too_many_arguments)]
75pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
76    strategy: &Strategy,
77    client: &ComputeClient<R>,
78    input: &MatmulInputHandleRef<'_, R>,
79    out_grad: &MatmulInputHandleRef<'_, R>,
80    weight_grad: &TensorHandleRef<'_, R>,
81    args: ConvolutionArgs<N_SPATIAL>,
82    dtypes: MatmulElems,
83) -> Result<(), ConvSetupError> {
84    let backprop = BackwardsWeight::new(client, input, out_grad, weight_grad, args, dtypes);
85
86    match strategy {
87        Strategy::Simple {
88            read_strategy,
89            tile_kind,
90        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
91            ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv<Accelerated>>(),
92            ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv<Accelerated>>(),
93            ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
94            ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
95            ReadingStrategy::AsyncStrided =>
96                backprop.launch::<SimpleAsyncStridedConv<Accelerated>>(),
97            ReadingStrategy::Tma => backprop.launch::<SimpleAsyncTmaConv<Accelerated>>(),
98        }),
99    }
100}
101
102#[derive(new)]
103struct BackwardsWeight<'a, R: Runtime, const N_SPATIAL: usize> {
104    client: &'a ComputeClient<R>,
105    input: &'a MatmulInputHandleRef<'a, R>,
106    out_grad: &'a MatmulInputHandleRef<'a, R>,
107    weight_grad: &'a TensorHandleRef<'a, R>,
108    args: ConvolutionArgs<N_SPATIAL>,
109    dtypes: MatmulElems,
110}
111
112impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsWeight<'a, R, N_SPATIAL> {
113    fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
114    where
115        Alg::Args: ConcreteArgs<Alg::Routine>,
116    {
117        let ConvolutionArgs {
118            stride,
119            padding,
120            dilation,
121        } = self.args;
122
123        let dimensionality = match N_SPATIAL {
124            1 => Dimensionality::Dim1,
125            2 => Dimensionality::Dim2,
126            3 => Dimensionality::Dim3,
127            other => unimplemented!("Unsupported dimensionality {other}"),
128        };
129
130        launch_with_algorithm::<R, Alg>(
131            self.client,
132            self.input,
133            self.out_grad,
134            self.weight_grad,
135            (&stride, &padding, &dilation),
136            dimensionality,
137            self.dtypes,
138        )
139    }
140}
141
142#[allow(clippy::too_many_arguments)]
143fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
144    client: &ComputeClient<R>,
145    input: &MatmulInputHandleRef<'_, R>,
146    out_grad: &MatmulInputHandleRef<'_, R>,
147    weight_grad: &TensorHandleRef<'_, R>,
148    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
149    dimensionality: Dimensionality,
150    dtypes: MatmulElems,
151) -> Result<(), ConvSetupError>
152where
153    Alg::Args: ConcreteArgs<Alg::Routine>,
154{
155    let rank = input.data().shape.len();
156    let dim_c = rank - 1;
157
158    let n = input.shape()[0];
159    let c = input.shape()[dim_c];
160
161    let out_c = out_grad.shape()[dim_c];
162
163    let in_shape = &input.shape()[1..dim_c];
164    let kernel_shape = &weight_grad.shape[1..dim_c];
165    let out_shape = &out_grad.shape()[1..dim_c];
166
167    let op = ConvolutionOperation::BackwardWeight;
168
169    let input_data = Alg::into_tensor_handle(client, input.data(), dtypes.lhs_global, op)?;
170    let out_grad_data = Alg::into_tensor_handle(client, out_grad.data(), dtypes.rhs_global, op)?;
171
172    let mut input = *input;
173    let mut out_grad = *out_grad;
174
175    *input.data_mut() = input_data.as_ref();
176    *out_grad.data_mut() = out_grad_data.as_ref();
177
178    let problem = ConvolutionProblem {
179        m: out_c,
180        n: c * kernel_shape.iter().product::<usize>(),
181        k: n * out_shape.iter().product::<usize>(),
182        lhs_strides: input.data().strides.to_vec(),
183        rhs_strides: out_grad.data().strides.to_vec(),
184        lhs_layout: definition::MatrixLayout::ColMajor,
185        rhs_layout: definition::MatrixLayout::RowMajor,
186        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
187        stride: stride.iter().map(|it| *it as u32).collect(),
188        padding: padding.iter().map(|it| *it as i32).collect(),
189        dilation: dilation.iter().map(|it| *it as u32).collect(),
190
191        batches: n,
192        in_shape: in_shape.to_vec(),
193        out_shape: out_shape.to_vec(),
194        channels: c,
195        out_channels: out_c,
196
197        padded_channels: c,
198        operation: op,
199
200        dimensionality,
201        global_dtypes: dtypes.as_global_elems(),
202    };
203
204    launch_kernel::<R, Alg>(
205        client,
206        &input,
207        &out_grad,
208        weight_grad,
209        problem,
210        &BlueprintStrategy::Inferred(Default::default()),
211        dtypes,
212    )
213}
214
215#[allow(clippy::result_large_err, clippy::too_many_arguments)]
216pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
217    client: &ComputeClient<R>,
218    input: &MatmulInputHandleRef<'_, R>,
219    out_grad: &MatmulInputHandleRef<'_, R>,
220    weight_grad: &TensorHandleRef<'_, R>,
221    problem: ConvolutionProblem,
222    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
223    dtypes: MatmulElems,
224) -> Result<(), ConvSetupError>
225where
226    Alg::Args: ConcreteArgs<Alg::Routine>,
227{
228    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
229    // So for the sake of selecting a line size, the shape/strides are always row-major.
230    let line_sizes = AvailableLineSizes::from_type_sizes(
231        client,
232        input.data().elem_size,
233        out_grad.data().elem_size,
234        weight_grad.elem_size,
235    )
236    .filter_lhs_with_tensor(
237        out_grad.data().strides,
238        out_grad.data().shape,
239        MatrixLayout::RowMajor,
240    )
241    .filter_rhs_with_tensor(
242        input.data().strides,
243        input.data().shape,
244        MatrixLayout::RowMajor,
245    )
246    .filter_out_with_tensor(weight_grad.strides, weight_grad.shape);
247
248    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
249
250    launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
251        client,
252        input,
253        out_grad,
254        weight_grad,
255        problem,
256        line_sizes,
257        blueprint_strategy,
258        &dtypes,
259    )
260}