
// cubek_convolution/kernels/forward/launch.rs

use crate::{
    AcceleratedTileKind, ReadingStrategy, algorithm::simple::*,
    components::global::args::RuntimeArgs,
};
use crate::{
    ConvolutionArgs, Strategy, components::ConvolutionOperation, forward::args::ConcreteArgs,
};
use crate::{
    algorithm::Algorithm,
    components::{ConvolutionProblem, Dimensionality},
};
use crate::{components::ConvSetupError, kernels::forward::selector::launch_kernel_concrete};
use cubecl::{Runtime, client::ComputeClient, prelude::*};
use cubek_matmul::routines::{BlueprintStrategy, TilingArgs};
use cubek_matmul::{
    components::tile_matmul::DispatchTileMatmul,
    definition::{AvailableVectorSizes, MatmulElems},
    routines::Routine,
};
use cubek_std::{InputBinding, MatrixLayout};
use derive_new::new;

/// Maps the convolution-level accelerated tile kind to the matmul dispatch equivalent.
fn tile_kind_to_dispatch(kind: &AcceleratedTileKind) -> DispatchTileMatmul {
    match kind {
        AcceleratedTileKind::Cmma => DispatchTileMatmul::Cmma,
        AcceleratedTileKind::Mma => DispatchTileMatmul::Mma,
    }
}

/// Perform an n-dimensional convolution with the implicit GEMM (im2col) algorithm, built on
/// cubecl tiling matmul components, using the specified algorithm.
///
/// * `strategy` - The algorithm and reading strategy to use for the convolution
/// * `input` - The input feature map; layout should be [batches, depth, height, width, in_channels]
/// * `weight` - The weights (filter) applied to each kernel; layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
/// * `bias` - The bias added to each out channel
/// * `out` - The output feature map; layout should be [batches, out_depth, out_height, out_width, out_channels]
/// * `args` - The stride, padding, and dilation to use for the convolution
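///
/// A minimal 2D call sketch (a hypothetical setup: the bindings, `client`, and `dtypes` are
/// assumed to already exist, and the `ConvolutionArgs` fields are assumed to be
/// `[usize; N_SPATIAL]` arrays):
///
/// ```ignore
/// let strategy = Strategy::Simple {
///     read_strategy: ReadingStrategy::Cyclic,
///     tile_kind: AcceleratedTileKind::Cmma,
/// };
/// let args = ConvolutionArgs {
///     stride: [1, 1],
///     padding: [0, 0],
///     dilation: [1, 1],
/// };
/// launch_ref::<R, 2>(&strategy, &client, input, weight, Some(bias), out, args, dtypes)?;
/// ```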
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
    strategy: &Strategy,
    client: &ComputeClient<R>,
    input: InputBinding<R>,
    weight: InputBinding<R>,
    bias: Option<InputBinding<R>>,
    out: TensorBinding<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError> {
    let conv = Convolution::new(client, input, weight, bias, out, args, dtypes);

    // Dispatch on the reading strategy: each variant selects a different loading scheme
    // for the simple algorithm.
    match strategy {
        Strategy::Simple {
            read_strategy,
            tile_kind,
        } => {
            let kind = tile_kind_to_dispatch(tile_kind);
            match read_strategy {
                ReadingStrategy::Cyclic => conv.launch::<SimpleSyncCyclicConv>(kind),
                ReadingStrategy::Strided => conv.launch::<SimpleSyncStridedConv>(kind),
                ReadingStrategy::Tilewise => conv.launch::<SimpleSyncTilewiseConv>(kind),
                ReadingStrategy::AsyncCyclic => conv.launch::<SimpleAsyncCyclicConv>(kind),
                ReadingStrategy::AsyncStrided => conv.launch::<SimpleAsyncStridedConv>(kind),
                ReadingStrategy::Tma => conv.launch::<SimpleAsyncTmaConv>(kind),
            }
        }
    }
}

/// Bundles the launch operands so that each algorithm can be dispatched with a single
/// generic call.
#[derive(new)]
struct Convolution<'a, R: Runtime, const N_SPATIAL: usize> {
    client: &'a ComputeClient<R>,
    input: InputBinding<R>,
    weight: InputBinding<R>,
    bias: Option<InputBinding<R>>,
    out: TensorBinding<R>,
    args: ConvolutionArgs<N_SPATIAL>,
    dtypes: MatmulElems,
}

impl<'a, R: Runtime, const N_SPATIAL: usize> Convolution<'a, R, N_SPATIAL> {
    fn launch<Alg: Algorithm>(self, tile_matmul: DispatchTileMatmul) -> Result<(), ConvSetupError>
    where
        Alg::Args: ConcreteArgs<Alg::Routine>,
        <Alg::Routine as Routine<RuntimeArgs>>::Strategy: TilingArgs,
    {
        let ConvolutionArgs {
            stride,
            padding,
            dilation,
        } = self.args;

        let dimensionality = match N_SPATIAL {
            1 => Dimensionality::Dim1,
            2 => Dimensionality::Dim2,
            3 => Dimensionality::Dim3,
            other => unimplemented!("Unsupported dimensionality {other}"),
        };

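        // Start from the routine's default tiling strategy and pin only the tile matmul
        // dispatch chosen by the caller.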
        let mut args = <Alg::Routine as Routine<RuntimeArgs>>::Strategy::default();
        args.set_tile_matmul(tile_matmul);

        launch_with_algorithm::<R, Alg>(
            self.client,
            self.input,
            self.weight,
            self.bias,
            self.out,
            (&stride, &padding, &dilation),
            dimensionality,
            &BlueprintStrategy::Inferred(args),
            self.dtypes,
        )
    }
}

#[allow(clippy::too_many_arguments)]
fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: InputBinding<R>,
    weight: InputBinding<R>,
    bias: Option<InputBinding<R>>,
    out: TensorBinding<R>,
    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
    dimensionality: Dimensionality,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
{
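    // Channels-last layouts: dim 0 is the batch (out_channels for the weights), the last
    // dim holds channels, and the dims in between are the spatial extents.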
    let rank = input.data().shape.len();
    let dim_c = rank - 1;

    let n = input.data().shape[0];
    let c = input.data().shape[dim_c];

    let out_c = weight.data().shape[0];

    let in_shape = &input.data().shape[1..dim_c];
    let kernel_shape = &weight.data().shape[1..dim_c];
    let out_shape = &out.shape[1..dim_c];

    let op = ConvolutionOperation::Forward;

    // Let the algorithm correct the operand layouts if needed, then swap the corrected
    // data back into the bindings.
    let input_data = Alg::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
    let weight_data =
        Alg::correct_layout(client, weight.clone().into_data(), dtypes.rhs_global, op)?;

    let mut input = input.clone();
    let mut weight = weight.clone();

    *input.data_mut() = input_data;
    *weight.data_mut() = weight_data;

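    // Pick the widest address type required by any operand; the output requirement is
    // sized by the accumulator type.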
    let address_type = input
        .required_address_type()
        .max(weight.required_address_type())
        .max(
            bias.clone()
                .map(|bias| bias.required_address_type())
                .unwrap_or_default(),
        )
        .max(out.required_address_type(dtypes.acc_global.size()));

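    // Implicit GEMM (im2col) mapping: m spans every output position across the batch,
    // n spans the output channels, and k spans the input channels times the kernel volume.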
    let problem = ConvolutionProblem {
        m: n * out_shape.iter().product::<usize>(),
        n: out_c,
        k: c * kernel_shape.iter().product::<usize>(),
        lhs_strides: input.data().strides.clone(),
        rhs_strides: weight.data().strides.clone(),
        lhs_layout: MatrixLayout::RowMajor,
        rhs_layout: MatrixLayout::ColMajor,
        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
        stride: stride.iter().map(|it| *it as u32).collect(),
        padding: padding.iter().map(|it| *it as i32).collect(),
        dilation: dilation.iter().map(|it| *it as u32).collect(),

        batches: n,
        in_shape: in_shape.into(),
        out_shape: out_shape.into(),
        channels: c,
        out_channels: out_c,

        padded_channels: c,
        operation: op,

        dimensionality,
        global_dtypes: dtypes.as_global_elems(),
        address_type,
    };

    launch_kernel::<R, Alg>(
        client,
        input,
        weight,
        bias,
        out,
        problem,
        blueprint_strategy,
        dtypes,
    )
}

/// Select vector sizes for the given problem and launch the concrete kernel.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
    client: &ComputeClient<R>,
    input: InputBinding<R>,
    weight: InputBinding<R>,
    bias: Option<InputBinding<R>>,
    out: TensorBinding<R>,
    problem: ConvolutionProblem,
    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
    dtypes: MatmulElems,
) -> Result<(), ConvSetupError>
where
    Alg::Args: ConcreteArgs<Alg::Routine>,
{
    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
    // So for the sake of selecting a vector size, the shape/strides are always row-major.
    let vector_sizes = AvailableVectorSizes::from_type_sizes(
        client,
        input.data_elem_size(),
        weight.data_elem_size(),
        dtypes.acc_global.size(),
    )
    .filter_lhs_with_tensor(
        &input.data().strides,
        &input.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_rhs_with_tensor(
        &weight.data().strides,
        &weight.data().shape,
        MatrixLayout::RowMajor,
    )
    .filter_out_with_tensor(&out.strides, &out.shape);

    let mut vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;

    // The large vector sizes that result from dequantization end up slower due to
    // restrictions on the algorithms, so force scalar loads as a quick and dirty fix.
    if input.scale().is_some() {
        vector_sizes.lhs = 1;
    }
    if weight.scale().is_some() {
        vector_sizes.rhs = 1;
    }

    launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
        client,
        input,
        weight,
        bias,
        out,
        problem,
        vector_sizes,
        blueprint_strategy,
        &dtypes,
    )
}