// cubek_convolution/kernels/forward/launch.rs
use crate::{
    AcceleratedTileKind, ConvolutionArgs, ReadingStrategy, Strategy,
    algorithm::{Algorithm, simple::*},
    components::{
        ConvSetupError, ConvolutionOperation, ConvolutionProblem, Dimensionality,
        global::args::RuntimeArgs,
    },
    forward::args::ConcreteArgs,
    kernels::forward::selector::launch_kernel_concrete,
};
use cubecl::{
    Runtime,
    client::ComputeClient,
    prelude::*,
    std::{CubeOption, tensor::TensorHandle},
};
use cubek_matmul::{
    components::tile::{cmma::CmmaMatmul, io::Strided, mma::MmaMatmul},
    definition::{self, AvailableLineSizes, MatmulElems, MatrixLayout},
    launch::{MatmulInputHandle, MatmulInputHandleRef},
    routines::BlueprintStrategy,
};
use derive_new::new;

27macro_rules! with_tile_kind {
28    ($kind: expr, $T: ident, $launch: expr) => {
29        match $kind {
30            AcceleratedTileKind::Cmma => {
31                type $T = CmmaMatmul<CubeOption<Strided>>;
32                ($launch)()
33            }
34            AcceleratedTileKind::Mma => {
35                type $T = MmaMatmul<Strided, Strided, CubeOption<Strided>>;
36                ($launch)()
37            }
38        }
39    };
40}
41
42#[allow(clippy::result_large_err, clippy::too_many_arguments)]
43pub fn launch<R: Runtime, const N_SPATIAL: usize>(
44    strategy: &Strategy,
45    client: &ComputeClient<R>,
46    input: MatmulInputHandle<R>,
47    weight: MatmulInputHandle<R>,
48    bias: Option<MatmulInputHandle<R>>,
49    out: TensorHandle<R>,
50    args: ConvolutionArgs<N_SPATIAL>,
51    dtypes: MatmulElems,
52) -> Result<(), ConvSetupError> {
53    launch_ref(
54        strategy,
55        client,
56        &input.as_ref(),
57        &weight.as_ref(),
58        &bias.as_ref().map(|it| it.as_ref()),
59        &out.as_ref(),
60        args,
61        dtypes,
62    )
63}
64
65/// Perform an n-dimensional convolution using the implicit GEMM (im2col) algorithm, using cubecl
66/// tiling matmul components, using the specified algorithm.
67///
68/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
69/// * `weight` - The weights (filter) applied to each kernel, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
70/// * `out` - The output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
71/// * `bias` - The bias added to each out channel
72/// * `options` - The options to use for the convolution
73#[allow(clippy::result_large_err, clippy::too_many_arguments)]
74pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
75    strategy: &Strategy,
76    client: &ComputeClient<R>,
77    input: &MatmulInputHandleRef<'_, R>,
78    weight: &MatmulInputHandleRef<'_, R>,
79    bias: &Option<MatmulInputHandleRef<'_, R>>,
80    out: &TensorHandleRef<'_, R>,
81    args: ConvolutionArgs<N_SPATIAL>,
82    dtypes: MatmulElems,
83) -> Result<(), ConvSetupError> {
84    let conv = Convolution::new(client, input, weight, bias, out, args, dtypes);
85
86    match strategy {
87        Strategy::Simple {
88            read_strategy,
89            tile_kind,
90        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
91            ReadingStrategy::Cyclic => conv.launch::<SimpleSyncCyclicConv<Accelerated>>(),
92            ReadingStrategy::Strided => conv.launch::<SimpleSyncStridedConv<Accelerated>>(),
93            ReadingStrategy::Tilewise => conv.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
94            ReadingStrategy::AsyncCyclic => conv.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
95            ReadingStrategy::AsyncStrided => conv.launch::<SimpleAsyncStridedConv<Accelerated>>(),
96            ReadingStrategy::Tma => conv.launch::<SimpleAsyncTmaConv<Accelerated>>(),
97        }),
98    }
99}
100
101#[derive(new)]
102struct Convolution<'a, R: Runtime, const N_SPATIAL: usize> {
103    client: &'a ComputeClient<R>,
104    input: &'a MatmulInputHandleRef<'a, R>,
105    weight: &'a MatmulInputHandleRef<'a, R>,
106    bias: &'a Option<MatmulInputHandleRef<'a, R>>,
107    out: &'a TensorHandleRef<'a, R>,
108    args: ConvolutionArgs<N_SPATIAL>,
109    dtypes: MatmulElems,
110}
111
112impl<'a, R: Runtime, const N_SPATIAL: usize> Convolution<'a, R, N_SPATIAL> {
113    fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
114    where
115        Alg::Args: ConcreteArgs<Alg::Routine>,
116    {
117        let ConvolutionArgs {
118            stride,
119            padding,
120            dilation,
121        } = self.args;
122
123        let dimensionality = match N_SPATIAL {
124            1 => Dimensionality::Dim1,
125            2 => Dimensionality::Dim2,
126            3 => Dimensionality::Dim3,
127            other => unimplemented!("Unsupported dimensionality {other}"),
128        };
129
130        launch_with_algorithm::<R, Alg>(
131            self.client,
132            self.input,
133            self.weight,
134            self.bias,
135            self.out,
136            (&stride, &padding, &dilation),
137            dimensionality,
138            &BlueprintStrategy::Inferred(Default::default()),
139            self.dtypes,
140        )
141    }
142}
143
144#[allow(clippy::too_many_arguments)]
145fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
146    client: &ComputeClient<R>,
147    input: &MatmulInputHandleRef<'_, R>,
148    weight: &MatmulInputHandleRef<'_, R>,
149    bias: &Option<MatmulInputHandleRef<'_, R>>,
150    out: &TensorHandleRef<'_, R>,
151    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
152    dimensionality: Dimensionality,
153    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
154    dtypes: MatmulElems,
155) -> Result<(), ConvSetupError>
156where
157    Alg::Args: ConcreteArgs<Alg::Routine>,
158{
159    let rank = input.data().shape.len();
160    let dim_c = rank - 1;
161
162    let n = input.data().shape[0];
163    let c = input.data().shape[dim_c];
164
165    let out_c = weight.data().shape[0];
166
167    let in_shape = &input.data().shape[1..dim_c];
168    let kernel_shape = &weight.data().shape[1..dim_c];
169    let out_shape = &out.shape[1..dim_c];
170
171    let op = ConvolutionOperation::Forward;
172
173    let input_data = Alg::into_tensor_handle(client, input.data(), dtypes.lhs_global, op)?;
174    let weight_data = Alg::into_tensor_handle(client, weight.data(), dtypes.rhs_global, op)?;
175
176    let mut input = *input;
177    let mut weight = *weight;
178
179    *input.data_mut() = input_data.as_ref();
180    *weight.data_mut() = weight_data.as_ref();
181
182    let problem = ConvolutionProblem {
183        m: n * out_shape.iter().product::<usize>(),
184        n: out_c,
185        k: c * kernel_shape.iter().product::<usize>(),
186        lhs_strides: input.data().strides.to_vec(),
187        rhs_strides: weight.data().strides.to_vec(),
188        lhs_layout: definition::MatrixLayout::RowMajor,
189        rhs_layout: definition::MatrixLayout::ColMajor,
190        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
191        stride: stride.iter().map(|it| *it as u32).collect(),
192        padding: padding.iter().map(|it| *it as i32).collect(),
193        dilation: dilation.iter().map(|it| *it as u32).collect(),
194
195        batches: n,
196        in_shape: in_shape.to_vec(),
197        out_shape: out_shape.to_vec(),
198        channels: c,
199        out_channels: out_c,
200
201        padded_channels: c,
202        operation: op,
203
204        dimensionality,
205        global_dtypes: dtypes.as_global_elems(),
206    };
207
208    launch_kernel::<R, Alg>(
209        client,
210        &input,
211        &weight,
212        bias,
213        out,
214        problem,
215        blueprint_strategy,
216        dtypes,
217    )
218}
219
220#[allow(clippy::result_large_err, clippy::too_many_arguments)]
221pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
222    client: &ComputeClient<R>,
223    input: &MatmulInputHandleRef<'_, R>,
224    weight: &MatmulInputHandleRef<'_, R>,
225    bias: &Option<MatmulInputHandleRef<'_, R>>,
226    out: &TensorHandleRef<'_, R>,
227    problem: ConvolutionProblem,
228    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
229    dtypes: MatmulElems,
230) -> Result<(), ConvSetupError>
231where
232    Alg::Args: ConcreteArgs<Alg::Routine>,
233{
234    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
235    // So for the sake of selecting a line size, the shape/strides are always row-major.
236    let line_sizes = AvailableLineSizes::from_type_sizes(
237        client,
238        input.data().elem_size,
239        weight.data().elem_size,
240        out.elem_size,
241    )
242    .filter_lhs_with_tensor(
243        input.data().strides,
244        input.data().shape,
245        MatrixLayout::RowMajor,
246    )
247    .filter_rhs_with_tensor(
248        weight.data().strides,
249        weight.data().shape,
250        MatrixLayout::RowMajor,
251    )
252    .filter_out_with_tensor(out.strides, out.shape);
253
254    let mut line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
255
256    // The large line size resulting from dequantizing ends up slower due to restrictions on
257    // algorithms. Use this as a quick and dirty fix.
258    if input.scale().is_some() {
259        line_sizes.lhs = 1;
260    }
261    if weight.scale().is_some() {
262        line_sizes.rhs = 1;
263    }
264
265    launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
266        client,
267        input,
268        weight,
269        bias,
270        out,
271        problem,
272        line_sizes,
273        blueprint_strategy,
274        &dtypes,
275    )
276}