Skip to main content

cubek_convolution/kernels/forward/
launch.rs

1use crate::{
2    AcceleratedTileKind, ReadingStrategy, algorithm::simple::*,
3    components::global::args::RuntimeArgs,
4};
5use crate::{
6    ConvolutionArgs, Strategy, components::ConvolutionOperation, forward::args::ConcreteArgs,
7};
8use crate::{
9    algorithm::Algorithm,
10    components::{ConvolutionProblem, Dimensionality},
11};
12use crate::{components::ConvSetupError, kernels::forward::selector::launch_kernel_concrete};
13use cubecl::{Runtime, client::ComputeClient, prelude::*};
14use cubek_matmul::routines::BlueprintStrategy;
15use cubek_matmul::{
16    components::tile::{cmma::CmmaMatmul, mma::MmaMatmul},
17    definition::{AvailableVectorSizes, MatmulElems},
18};
19use cubek_std::{
20    InputBinding,
21    {MatrixLayout, tile::Strided},
22};
23use derive_new::new;
24
/// Expands `$launch` with the type alias `$T` bound to the concrete tile
/// matmul implementation selected at runtime by `$kind`.
///
/// The alias trick lets the caller's closure body refer to a single ident
/// (e.g. `Accelerated`) regardless of which tile backend was chosen.
macro_rules! with_tile_kind {
    ($kind: expr, $T: ident, $launch: expr) => {
        match $kind {
            AcceleratedTileKind::Cmma => {
                // CMMA-backed tile matmul.
                type $T = CmmaMatmul<Option<Strided>>;
                ($launch)()
            }
            AcceleratedTileKind::Mma => {
                // MMA-backed tile matmul with `Strided` lhs/rhs fragments.
                type $T = MmaMatmul<Strided, Strided, Option<Strided>>;
                ($launch)()
            }
        }
    };
}
39
40/// Perform an n-dimensional convolution using the implicit GEMM (im2col) algorithm, using cubecl
41/// tiling matmul components, using the specified algorithm.
42///
43/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
44/// * `weight` - The weights (filter) applied to each kernel, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
45/// * `out` - The output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
46/// * `bias` - The bias added to each out channel
47/// * `options` - The options to use for the convolution
48#[allow(clippy::result_large_err, clippy::too_many_arguments)]
49pub fn launch_ref<R: Runtime, const N_SPATIAL: usize>(
50    strategy: &Strategy,
51    client: &ComputeClient<R>,
52    input: InputBinding<R>,
53    weight: InputBinding<R>,
54    bias: Option<InputBinding<R>>,
55    out: TensorBinding<R>,
56    args: ConvolutionArgs<N_SPATIAL>,
57    dtypes: MatmulElems,
58) -> Result<(), ConvSetupError> {
59    let conv = Convolution::new(client, input, weight, bias, out, args, dtypes);
60
61    match strategy {
62        Strategy::Simple {
63            read_strategy,
64            tile_kind,
65        } => with_tile_kind!(tile_kind, Accelerated, || match read_strategy {
66            ReadingStrategy::Cyclic => conv.launch::<SimpleSyncCyclicConv<Accelerated>>(),
67            ReadingStrategy::Strided => conv.launch::<SimpleSyncStridedConv<Accelerated>>(),
68            ReadingStrategy::Tilewise => conv.launch::<SimpleSyncTilewiseConv<Accelerated>>(),
69            ReadingStrategy::AsyncCyclic => conv.launch::<SimpleAsyncCyclicConv<Accelerated>>(),
70            ReadingStrategy::AsyncStrided => conv.launch::<SimpleAsyncStridedConv<Accelerated>>(),
71            ReadingStrategy::Tma => conv.launch::<SimpleAsyncTmaConv<Accelerated>>(),
72        }),
73    }
74}
75
/// Bundles everything needed to launch a forward convolution so the concrete
/// algorithm can be chosen afterwards via `Convolution::launch`.
///
/// NOTE: `#[derive(new)]` generates the constructor from field order — keep
/// the field order in sync with call sites.
#[derive(new)]
struct Convolution<'a, R: Runtime, const N_SPATIAL: usize> {
    // Compute client used to launch the kernel.
    client: &'a ComputeClient<R>,
    // Input feature map binding.
    input: InputBinding<R>,
    // Filter (weights) binding.
    weight: InputBinding<R>,
    // Optional per-out-channel bias binding.
    bias: Option<InputBinding<R>>,
    // Output feature map.
    out: TensorBinding<R>,
    // Stride/padding/dilation, one entry per spatial dimension.
    args: ConvolutionArgs<N_SPATIAL>,
    // Element types used by the underlying matmul.
    dtypes: MatmulElems,
}
86
87impl<'a, R: Runtime, const N_SPATIAL: usize> Convolution<'a, R, N_SPATIAL> {
88    fn launch<Alg: Algorithm>(self) -> Result<(), ConvSetupError>
89    where
90        Alg::Args: ConcreteArgs<Alg::Routine>,
91    {
92        let ConvolutionArgs {
93            stride,
94            padding,
95            dilation,
96        } = self.args;
97
98        let dimensionality = match N_SPATIAL {
99            1 => Dimensionality::Dim1,
100            2 => Dimensionality::Dim2,
101            3 => Dimensionality::Dim3,
102            other => unimplemented!("Unsupported dimensionality {other}"),
103        };
104
105        launch_with_algorithm::<R, Alg>(
106            self.client,
107            self.input,
108            self.weight,
109            self.bias,
110            self.out,
111            (&stride, &padding, &dilation),
112            dimensionality,
113            &BlueprintStrategy::Inferred(Default::default()),
114            self.dtypes,
115        )
116    }
117}
118
119#[allow(clippy::too_many_arguments)]
120fn launch_with_algorithm<R: Runtime, Alg: Algorithm>(
121    client: &ComputeClient<R>,
122    input: InputBinding<R>,
123    weight: InputBinding<R>,
124    bias: Option<InputBinding<R>>,
125    out: TensorBinding<R>,
126    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
127    dimensionality: Dimensionality,
128    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
129    dtypes: MatmulElems,
130) -> Result<(), ConvSetupError>
131where
132    Alg::Args: ConcreteArgs<Alg::Routine>,
133{
134    let rank = input.data().shape.len();
135    let dim_c = rank - 1;
136
137    let n = input.data().shape[0];
138    let c = input.data().shape[dim_c];
139
140    let out_c = weight.data().shape[0];
141
142    let in_shape = &input.data().shape[1..dim_c];
143    let kernel_shape = &weight.data().shape[1..dim_c];
144    let out_shape = &out.shape[1..dim_c];
145
146    let op = ConvolutionOperation::Forward;
147
148    let input_data = Alg::correct_layout(client, input.clone().into_data(), dtypes.lhs_global, op)?;
149    let weight_data =
150        Alg::correct_layout(client, weight.clone().into_data(), dtypes.rhs_global, op)?;
151
152    let mut input = input.clone();
153    let mut weight = weight.clone();
154
155    *input.data_mut() = input_data;
156    *weight.data_mut() = weight_data;
157
158    let address_type = input
159        .required_address_type()
160        .max(weight.required_address_type())
161        .max(
162            bias.clone()
163                .map(|bias| bias.required_address_type())
164                .unwrap_or_default(),
165        )
166        .max(out.required_address_type(dtypes.acc_global.size()));
167
168    let problem = ConvolutionProblem {
169        m: n * out_shape.iter().product::<usize>(),
170        n: out_c,
171        k: c * kernel_shape.iter().product::<usize>(),
172        lhs_strides: input.data().strides.clone(),
173        rhs_strides: weight.data().strides.clone(),
174        lhs_layout: MatrixLayout::RowMajor,
175        rhs_layout: MatrixLayout::ColMajor,
176        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
177        stride: stride.iter().map(|it| *it as u32).collect(),
178        padding: padding.iter().map(|it| *it as i32).collect(),
179        dilation: dilation.iter().map(|it| *it as u32).collect(),
180
181        batches: n,
182        in_shape: in_shape.into(),
183        out_shape: out_shape.into(),
184        channels: c,
185        out_channels: out_c,
186
187        padded_channels: c,
188        operation: op,
189
190        dimensionality,
191        global_dtypes: dtypes.as_global_elems(),
192        address_type,
193    };
194
195    launch_kernel::<R, Alg>(
196        client,
197        input,
198        weight,
199        bias,
200        out,
201        problem,
202        blueprint_strategy,
203        dtypes,
204    )
205}
206
207#[allow(clippy::result_large_err, clippy::too_many_arguments)]
208pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
209    client: &ComputeClient<R>,
210    input: InputBinding<R>,
211    weight: InputBinding<R>,
212    bias: Option<InputBinding<R>>,
213    out: TensorBinding<R>,
214    problem: ConvolutionProblem,
215    blueprint_strategy: &BlueprintStrategy<RuntimeArgs, Alg::Routine>,
216    dtypes: MatmulElems,
217) -> Result<(), ConvSetupError>
218where
219    Alg::Args: ConcreteArgs<Alg::Routine>,
220{
221    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
222    // So for the sake of selecting a vector size, the shape/strides are always row-major.
223    let vector_sizes = AvailableVectorSizes::from_type_sizes(
224        client,
225        input.data_elem_size(),
226        weight.data_elem_size(),
227        dtypes.acc_global.size(),
228    )
229    .filter_lhs_with_tensor(
230        &input.data().strides,
231        &input.data().shape,
232        MatrixLayout::RowMajor,
233    )
234    .filter_rhs_with_tensor(
235        &weight.data().strides,
236        &weight.data().shape,
237        MatrixLayout::RowMajor,
238    )
239    .filter_out_with_tensor(&out.strides, &out.shape);
240
241    let mut vector_sizes = Alg::filter_vector_sizes(vector_sizes).pick_max()?;
242
243    // The large vector size resulting from dequantizing ends up slower due to restrictions on
244    // algorithms. Use this as a quick and dirty fix.
245    if input.scale().is_some() {
246        vector_sizes.lhs = 1;
247    }
248    if weight.scale().is_some() {
249        vector_sizes.rhs = 1;
250    }
251
252    launch_kernel_concrete::<R, Alg::Args, Alg::Routine>(
253        client,
254        input,
255        weight,
256        bias,
257        out,
258        problem,
259        vector_sizes,
260        blueprint_strategy,
261        &dtypes,
262    )
263}