// cubek_convolution/kernels/forward/launch.rs

1use crate::{AcceleratedTileKind, ReadingStrategy};
2use crate::{
3    ConvolutionArgs, Strategy,
4    components::{ConvGemmConfig as _, ConvolutionOperation},
5    forward::args::ConcreteArgs,
6    kernels::forward::simple::*,
7};
8use crate::{components::ConvSetupError, kernels::forward::selector::launch_kernel_concrete};
9use crate::{
10    components::{ConvolutionProblem, Dimensionality},
11    kernels::forward::algorithm::Algorithm,
12};
13use cubecl::{
14    Runtime,
15    client::ComputeClient,
16    prelude::*,
17    std::{CubeOption, tensor::TensorHandle},
18};
19use cubek_matmul::launch::MatmulInputHandle;
20use cubek_matmul::{
21    components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
22    definition::{AvailableLineSizes, MatmulElems, MatrixLayout},
23};
24use cubek_matmul::{definition, launch::MatmulInputHandleRef};
25use derive_new::new;
26
/// Dispatches a generic launch closure over a concrete accelerated tile
/// matmul implementation, chosen at runtime from an [`AcceleratedTileKind`].
///
/// Inside each arm, `$T` is bound as a local type alias to the selected tile
/// matmul type, so `$launch` (a zero-argument closure) can refer to `$T` and
/// be monomorphized once per backend.
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                // Cooperative-matrix (CMMA) tile backend.
                type $T = CmmaMatmul<CubeOption<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                // Raw MMA tile backend with strided lhs/rhs inputs.
                type $T = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
                ($launch)()
            }
        }
    };
}
41
42#[allow(clippy::result_large_err, clippy::too_many_arguments)]
43pub fn launch<R: Runtime, const N_SPATIAL: usize>(
44    strategy: &Strategy,
45    client: &ComputeClient<R>,
46    input: MatmulInputHandle<R>,
47    weight: MatmulInputHandle<R>,
48    bias: Option<MatmulInputHandle<R>>,
49    out: TensorHandle<R>,
50    args: ConvolutionArgs<N_SPATIAL>,
51    dtypes: MatmulElems,
52) -> Result<(), ConvSetupError> {
53    launch_ref(
54        strategy,
55        client,
56        &input.as_ref(),
57        &weight.as_ref(),
58        &bias.as_ref().map(|it| it.as_ref()),
59        &out.as_ref(),
60        args,
61        dtypes,
62    )
63}
64
65/// Perform an n-dimensional convolution using the implicit GEMM (im2col) algorithm, using cubecl
66/// tiling matmul components, using the specified algorithm.
67///
68/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
69/// * `weight` - The weights (filter) applied to each kernel, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
70/// * `out` - The output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
71/// * `bias` - The bias added to each out channel
72/// * `options` - The options to use for the convolution
73#[allow(clippy::result_large_err, clippy::too_many_arguments)]
74pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
75    strategy: &Strategy,
76    client: &ComputeClient<R>,
77    input: &MatmulInputHandleRef<'_, R>,
78    weight: &MatmulInputHandleRef<'_, R>,
79    bias: &Option<MatmulInputHandleRef<'_, R>>,
80    out: &TensorHandleRef<'_, R>,
81    args: ConvolutionArgs<N_SPATIAL>,
82    dtypes: MatmulElems,
83) -> Result<(), ConvSetupError> {
84    let conv = Convolution::new(client, input, weight, bias, out, args, dtypes);
85
86    match strategy {
87        Strategy::Simple {
88            read_strategy,
89            tile_kind,
90        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
91            ReadingStrategy::Cyclic => conv.launch::<SimpleSyncCyclicConv<Accelerated>>(),
92            ReadingStrategy::Strided => conv.launch::<SimpleSyncStridedConv<Accelerated>>(),
93            ReadingStrategy::Tilewise => conv.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
94            ReadingStrategy::AsyncCyclic => conv.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
95            ReadingStrategy::AsyncStrided => conv.launch::<SimpleAsyncStridedConv<Accelerated>>(),
96            ReadingStrategy::Tma => conv.launch::<SimpleAsyncTmaConv<Accelerated>>(),
97        }),
98    }
99}
100
/// Bundles every argument needed to launch a forward convolution, so the
/// per-algorithm dispatch in [`launch_ref`] can pass a single value into the
/// generic `launch` method.
///
/// `#[derive(new)]` generates a positional constructor in field order.
#[derive(new)]
struct Convolution<'a, R: Runtime, const N_SPATIAL: usize> {
    // Compute client used to launch the kernel.
    client: &'a ComputeClient<R>,
    // Input feature map handle (GEMM lhs).
    input: &'a MatmulInputHandleRef<'a, R>,
    // Filter weights handle (GEMM rhs).
    weight: &'a MatmulInputHandleRef<'a, R>,
    // Optional per-output-channel bias.
    bias: &'a Option<MatmulInputHandleRef<'a, R>>,
    // Output feature map handle.
    out: &'a TensorHandleRef<'a, R>,
    // Stride/padding/dilation for `N_SPATIAL` spatial dims.
    args: ConvolutionArgs<N_SPATIAL>,
    // Element types used for the underlying matmul.
    dtypes: MatmulElems,
}
111
112impl<'a, R: Runtime, const N_SPATIAL: usize> Convolution<'a, R, N_SPATIAL> {
113    fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
114    where
115        Alg::Args: ConcreteArgs,
116    {
117        let ConvolutionArgs {
118            stride,
119            padding,
120            dilation,
121        } = self.args;
122
123        let dimensionality = match N_SPATIAL {
124            1 => Dimensionality::Dim1,
125            2 => Dimensionality::Dim2,
126            3 => Dimensionality::Dim3,
127            other => unimplemented!("Unsupported dimensionality {other}"),
128        };
129
130        launch_with_algorithm::<R, Alg>(
131            self.client,
132            self.input,
133            self.weight,
134            self.bias,
135            self.out,
136            (&stride, &padding, &dilation),
137            dimensionality,
138            self.dtypes,
139        )
140    }
141}
142
143#[allow(clippy::too_many_arguments)]
144fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
145    client: &ComputeClient<R>,
146    input: &MatmulInputHandleRef<'_, R>,
147    weight: &MatmulInputHandleRef<'_, R>,
148    bias: &Option<MatmulInputHandleRef<'_, R>>,
149    out: &TensorHandleRef<'_, R>,
150    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
151    dimensionality: Dimensionality,
152    dtypes: MatmulElems,
153) -> Result<(), ConvSetupError>
154where
155    Alg::Args: ConcreteArgs,
156{
157    let rank = input.data().shape.len();
158    let dim_c = rank - 1;
159
160    let n = input.data().shape[0];
161    let c = input.data().shape[dim_c];
162
163    let out_c = weight.data().shape[0];
164
165    let in_shape = &input.data().shape[1..dim_c];
166    let kernel_shape = &weight.data().shape[1..dim_c];
167    let out_shape = &out.shape[1..dim_c];
168
169    let op = ConvolutionOperation::Forward;
170
171    let input_data = Alg::into_tensor_handle(client, input.data(), dtypes.lhs_global, op)?;
172    let weight_data = Alg::into_tensor_handle(client, weight.data(), dtypes.rhs_global, op)?;
173
174    let mut input = *input;
175    let mut weight = *weight;
176
177    *input.data_mut() = input_data.as_ref();
178    *weight.data_mut() = weight_data.as_ref();
179
180    let problem = ConvolutionProblem {
181        m: n * out_shape.iter().product::<usize>(),
182        n: out_c,
183        k: c * kernel_shape.iter().product::<usize>(),
184        lhs_strides: input.data().strides.to_vec(),
185        rhs_strides: weight.data().strides.to_vec(),
186        lhs_layout: definition::MatrixLayout::RowMajor,
187        rhs_layout: definition::MatrixLayout::ColMajor,
188        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
189        stride: stride.iter().map(|it| *it as u32).collect(),
190        padding: padding.iter().map(|it| *it as i32).collect(),
191        dilation: dilation.iter().map(|it| *it as u32).collect(),
192
193        batches: n,
194        in_shape: in_shape.to_vec(),
195        out_shape: out_shape.to_vec(),
196        channels: c,
197        out_channels: out_c,
198
199        padded_channels: c,
200        operation: op,
201
202        dimensionality,
203        global_dtypes: dtypes.as_global_elems(),
204    };
205
206    launch_kernel::<R, Alg>(client, &input, &weight, bias, out, problem, dtypes)
207}
208
209#[allow(clippy::result_large_err, clippy::too_many_arguments)]
210pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
211    client: &ComputeClient<R>,
212    input: &MatmulInputHandleRef<'_, R>,
213    weight: &MatmulInputHandleRef<'_, R>,
214    bias: &Option<MatmulInputHandleRef<'_, R>>,
215    out: &TensorHandleRef<'_, R>,
216    problem: ConvolutionProblem,
217    mut dtypes: MatmulElems,
218) -> Result<(), ConvSetupError>
219where
220    Alg::Args: ConcreteArgs,
221{
222    let plane_dim = client.properties().hardware.plane_size_max;
223    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
224    // So for the sake of selecting a line size, the shape/strides are always row-major.
225    let line_sizes = AvailableLineSizes::from_type_sizes(
226        client,
227        input.data().elem_size,
228        weight.data().elem_size,
229        out.elem_size,
230    )
231    .filter_lhs_with_tensor(
232        input.data().strides,
233        input.data().shape,
234        MatrixLayout::RowMajor,
235    )
236    .filter_rhs_with_tensor(
237        weight.data().strides,
238        weight.data().shape,
239        MatrixLayout::RowMajor,
240    )
241    .filter_out_with_tensor(out.strides, out.shape);
242
243    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
244
245    let selection = Alg::selection(client, &problem, plane_dim, &line_sizes, &mut dtypes)?;
246    let problem = Alg::Args::adjust_problem(client, problem, &selection, &dtypes);
247
248    let config = Alg::expand_config(
249        client.properties(),
250        &problem,
251        &selection,
252        &line_sizes,
253        &dtypes,
254    )?;
255
256    let line_sizes = config.line_sizes();
257
258    launch_kernel_concrete::<R, Alg>(
259        client, input, weight, bias, out, problem, line_sizes, selection, &dtypes,
260    )
261}