
cubek_convolution/kernels/backward_weight/launch.rs

use crate::components::{ConvolutionProblem, Dimensionality};
use crate::{
    AcceleratedTileKind, ReadingStrategy, algorithm::Algorithm,
    components::global::args::RuntimeArgs,
};
use crate::{
    ConvolutionArgs, Strategy, backward_weight::args::ConcreteArgs,
    components::ConvolutionOperation, kernels::algorithm::simple::*,
};
use crate::{
    components::ConvSetupError, kernels::backward_weight::selector::launch_kernel_concrete,
};
use cubecl::{Runtime, client::ComputeClient, prelude::*};
use cubek_matmul::components::tile_matmul::DispatchTileMatmul;
use cubek_matmul::{
    definition::{AvailableVectorSizes, MatmulElems},
    routines::{BlueprintStrategy, Routine, TilingArgs},
};
use cubek_std::{InputBinding, MatrixLayout};
use derive_new::new;

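/// Maps the user-facing accelerated tile kind to the corresponding tile matmul dispatch.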
fn tile_kind_to_dispatch(kind: &AcceleratedTileKind) -> DispatchTileMatmul {
    match kind {
        AcceleratedTileKind::Cmma => DispatchTileMatmul::Cmma,
        AcceleratedTileKind::Mma => DispatchTileMatmul::Mma,
    }
}

/// Perform the weight-gradient (backward) pass of an n-dimensional convolution using the
/// implicit GEMM (im2col) algorithm, built on cubecl tiling matmul components, with the
/// specified strategy.
///
/// * `strategy` - The reading strategy and accelerated tile kind to use
/// * `client` - The compute client to launch on
/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
/// * `out_grad` - The gradient of the output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
/// * `weight_grad` - The computed weight gradient, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
/// * `args` - The stride, padding and dilation of the convolution
/// * `dtypes` - The element types to use for the matmul
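///
/// # Example
///
/// A minimal sketch of a call site, assuming the client, bindings and `args` have been
/// created elsewhere; `WgpuRuntime` is just one possible runtime.
///
/// ```ignore
/// let strategy = Strategy::Simple {
///     read_strategy: ReadingStrategy::Cyclic,
///     tile_kind: AcceleratedTileKind::Cmma,
/// };
/// launch_ref::<WgpuRuntime, 2>(&strategy, &client, input, out_grad, weight_grad, args, dtypes)?;
/// ```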
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    input: InputBinding<R>,
    out_grad: InputBinding<R>,
    weight_grad: TensorBinding<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    let backprop = BackwardsWeight::new(client, input, out_grad, weight_grad, args, dtypes);

    match strategy {
        Strategy::Simple {
            read_strategy,
            tile_kind,
        } => {
            let kind = tile_kind_to_dispatch(tile_kind);
            match read_strategy {
                ReadingStrategy::Cyclic => backprop.launch::<SimpleSyncCyclicConv>(kind),
                ReadingStrategy::Strided => backprop.launch::<SimpleSyncStridedConv>(kind),
                ReadingStrategy::Tilewise => backprop.launch::<SimpleSyncTilewiseConv>(kind),
                ReadingStrategy::AsyncCyclic => backprop.launch::<SimpleAsyncCyclicConv>(kind),
                ReadingStrategy::AsyncStrided => backprop.launch::<SimpleAsyncStridedConv>(kind),
                ReadingStrategy::Tma => backprop.launch::<SimpleAsyncTmaConv>(kind),
            }
        }
    }
}

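/// Bundles the launch inputs so the reading-strategy dispatch in [`launch_ref`] stays compact.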
#[derive(new)]
struct BackwardsWeight<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    input: InputBinding<R>,
    out_grad: InputBinding<R>,
    weight_grad: TensorBinding<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
}

impl<'a, R: Runtime, const N_SPATIAL: usize> BackwardsWeight<'a, R, N_SPATIAL> {
    fn launch<Alg: Algorithm>(self, tile_matmul: DispatchTileMatmul) -> Result<(), ConvSetupError>
    where
        Alg::Args: ConcreteArgs<Alg::Routine>,
        <Alg::Routine as Routine<RuntimeArgs>>::Strategy: TilingArgs,
    {
        let ConvolutionArgs {
            stride,
            padding,
            dilation,
        } = self.args;

        let dimensionality = match N_SPATIAL {
            1 => Dimensionality::Dim1,
            2 => Dimensionality::Dim2,
            3 => Dimensionality::Dim3,
            other => unimplemented!("Unsupported dimensionality {other}"),
        };

        launch_with_algorithm::<R, Alg>(
            self.client,
            self.input,
            self.out_grad,
            self.weight_grad,
            (&stride, &padding, &dilation),
            dimensionality,
            tile_matmul,
            self.dtypes,
        )
    }
}

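/// Builds the [`ConvolutionProblem`] for the backward-weight GEMM from the tensor shapes and
/// convolution parameters, then launches the kernel with an inferred blueprint.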
#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: InputBinding<R>,
    out_grad: InputBinding<R>,
    weight_grad: TensorBinding<R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    tile_matmul: DispatchTileMatmul,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
    <Alg::Routine as Routine<RuntimeArgs>>::Strategy: TilingArgs,
{
    let rank = input.data().shape.len();
    let dim_c = rank - 1;

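    // Channels-last layout: the batch dimension comes first, the channel dimension last,
    // with the spatial dimensions in between.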
    let n = input.shape()[0];
    let c = input.shape()[dim_c];

    let out_c = out_grad.shape()[dim_c];

    let in_shape = &input.shape()[1..dim_c];
    let kernel_shape = &weight_grad.shape[1..dim_c];
    let out_shape = &out_grad.shape()[1..dim_c];

    let op = ConvolutionOperation::BackwardWeight;

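    // Let the algorithm rewrite each operand into the layout it requires for this operation,
    // then swap the adjusted data back into the bindings.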
    let input_data = Alg::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
    let out_grad_data =
        Alg::correct_layout(client, out_grad.clone().into_data(), dtypes.rhs_global, op)?;

    let mut input = input.clone();
    let mut out_grad = out_grad.clone();

    *input.data_mut() = input_data;
    *out_grad.data_mut() = out_grad_data;

    let address_type = input
        .required_address_type()
        .max(out_grad.required_address_type())
        .max(weight_grad.required_address_type(dtypes.acc_global.size()));

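    // The backward-weight pass as an implicit GEMM: the weight gradient is viewed as an
    // m x n matrix with m = out_channels and n = in_channels * kernel volume, contracted
    // over k = batches * output spatial positions.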
    let problem = ConvolutionProblem {
        m: out_c,
        n: c * kernel_shape.iter().product::<usize>(),
        k: n * out_shape.iter().product::<usize>(),
        lhs_strides: input.data().strides.clone(),
        rhs_strides: out_grad.data().strides.clone(),
        lhs_layout: MatrixLayout::ColMajor,
        rhs_layout: MatrixLayout::RowMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.into(),
        out_shape: out_shape.into(),
        channels: c,
        out_channels: out_c,

        padded_channels: c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
        address_type,
    };

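    // Start from the routine's default tiling strategy and pin the requested tile matmul
    // backend; the rest of the blueprint is inferred at launch.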
    let mut args = <Alg::Routine as Routine<RuntimeArgs>>::Strategy::default();
    args.set_tile_matmul(tile_matmul);

    launch_kernel::<R, Alg>(
        client,
        input,
        out_grad,
        weight_grad,
        problem,
        &BlueprintStrategy::Inferred(args),
        dtypes,
    )
}

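/// Launches the backward-weight kernel for an already-constructed [`ConvolutionProblem`],
/// picking the largest vector sizes supported by the operand tensors and the algorithm.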
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: InputBinding<R>,
    out_grad: InputBinding<R>,
    weight_grad: TensorBinding<R>,
    problem: ConvolutionProblem,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
{
    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
    // So for the sake of selecting a vector size, the shape/strides are always row-major.
    let vector_sizes = AvailableVectorSizes::from_type_sizes(
        client,
        input.data_elem_size(),
        out_grad.data_elem_size(),
        dtypes.acc_global.size(),
    )
    .filter_lhs_with_tensor(
        &out_grad.data().strides,
        &out_grad.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        &input.data().strides,
        &input.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(&weight_grad.strides, &weight_grad.shape);

    let vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;

    launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
        client,
        input,
        out_grad,
        weight_grad,
        problem,
        vector_sizes,
        blueprint_strategy,
        &dtypes,
    )
}