cubecl_convolution/
launch.rs

1use std::any::TypeId;
2
3use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
4use cubecl_matmul::components::global::GlobalConfig;
5use half::f16;
6
7use crate::ConvGemmConfig;
8use crate::base::ConvolutionLaunch;
9use cubecl_matmul::components::global::args::{ConcreteOutputFactory, MatmulArgs};
10use cubecl_matmul::components::{
11    self, AvailableLineSizes, InputIdent, MatmulPrecision, MatmulSelection,
12};
13
14use super::{
15    ConvLaunchError,
16    algorithm::Algorithm,
17    args::ConvInputsLaunch,
18    base::{ConvolutionProblem, Dimensionality},
19};
20
21type Input<Alg, MP> = <<Alg as Algorithm>::Args as MatmulArgs>::Input<<MP as MatmulPrecision>::EI>;
22type Output<Alg, MP> =
23    <<Alg as Algorithm>::Args as MatmulArgs>::Output<<MP as MatmulPrecision>::EO>;
24
/// Per-spatial-dimension parameters of an `N_SPATIAL`-dimensional convolution
/// (e.g. `[h, w]` order for a 2D convolution).
///
/// All three arrays have one entry per spatial dimension, matching the
/// dimensionality passed to [`launch_conv`].
// `Copy`/`Debug`/`PartialEq`/`Eq` are derivable for free here (fixed-size
// arrays of `usize`) and expected on a small public parameter struct.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Step of the kernel window along each spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Padding applied to each spatial dimension.
    pub padding: [usize; N_SPATIAL],
    /// Spacing between kernel elements along each spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
31
32/// Perform an n-dimensional convolution using the implicit GEMM (im2col) algorithm, using cubecl
33/// tiling matmul components, using the specified algorithm.
34///
35/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
36/// * `weight` - The weights (filter) applied to each kernel, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
37/// * `out` - The output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
38/// * `bias` - The bias added to each out channel
39/// * `options` - The options to use for the convolution
40#[allow(clippy::result_large_err)]
41pub fn launch_conv<R: Runtime, MP: MatmulPrecision, Alg: Algorithm, const N_SPATIAL: usize>(
42    client: &ComputeClient<R::Server, R::Channel>,
43    input: &TensorHandleRef<'_, R>,
44    weight: &TensorHandleRef<'_, R>,
45    bias: &Option<TensorHandleRef<'_, R>>,
46    out: &TensorHandleRef<'_, R>,
47    args: ConvolutionArgs<N_SPATIAL>,
48) -> Result<(), ConvLaunchError>
49where
50    Input<Alg, MP>: ConvInputsLaunch,
51    Output<Alg, MP>: ConcreteOutputFactory,
52{
53    let ConvolutionArgs {
54        stride,
55        padding,
56        dilation,
57    } = args;
58
59    let dimensionality = match N_SPATIAL {
60        1 => Dimensionality::Dim1,
61        2 => Dimensionality::Dim2,
62        3 => Dimensionality::Dim3,
63        other => unimplemented!("Unsupported dimensionality {other}"),
64    };
65
66    launch::<R, MP, Alg>(
67        client,
68        input,
69        weight,
70        bias,
71        out,
72        (&stride, &padding, &dilation),
73        dimensionality,
74    )
75}
76
77fn launch<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
78    client: &ComputeClient<R::Server, R::Channel>,
79    input: &TensorHandleRef<'_, R>,
80    weight: &TensorHandleRef<'_, R>,
81    bias: &Option<TensorHandleRef<'_, R>>,
82    out: &TensorHandleRef<'_, R>,
83    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
84    dimensionality: Dimensionality,
85) -> Result<(), ConvLaunchError>
86where
87    Input<Alg, MP>: ConvInputsLaunch,
88    Output<Alg, MP>: ConcreteOutputFactory,
89{
90    let rank = input.shape.len();
91    let dim_c = rank - 1;
92
93    let n = input.shape[0];
94    let c = input.shape[dim_c];
95
96    let out_c = weight.shape[0];
97
98    let in_shape = &input.shape[1..dim_c];
99    let kernel_shape = &weight.shape[1..dim_c];
100    let out_shape = &out.shape[1..dim_c];
101
102    let input = Alg::into_tensor_handle::<R, MP::EI>(client, input, InputIdent::Lhs);
103    let weight = Alg::into_tensor_handle::<R, MP::EI>(client, weight, InputIdent::Rhs);
104
105    let plane_dim = client.properties().hardware.plane_size_max;
106
107    let problem = ConvolutionProblem {
108        m: n * out_shape.iter().product::<usize>(),
109        n: out_c,
110        k: c * kernel_shape.iter().product::<usize>(),
111        lhs_layout: components::MatrixLayout::RowMajor,
112        rhs_layout: components::MatrixLayout::ColMajor,
113        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
114        stride: stride.iter().map(|it| *it as u32).collect(),
115        padding: padding.iter().map(|it| *it as i32).collect(),
116        dilation: dilation.iter().map(|it| *it as u32).collect(),
117
118        batches: n,
119        shape: in_shape.to_vec(),
120        out_shape: out_shape.to_vec(),
121        channels: c,
122
123        dimensionality,
124    };
125
126    let selection = Alg::selection::<R>(
127        client,
128        &problem,
129        plane_dim,
130        MP::ES::as_elem_native_unchecked(),
131        MP::EA::as_elem_native_unchecked(),
132    );
133
134    let launch = if TypeId::of::<MP::EI>() == TypeId::of::<f32>() {
135        if tf32::is_supported(client) {
136            launch_kernel::<R, (MP::EI, tf32, f32, MP::EO), Alg>
137        } else {
138            launch_kernel::<R, (MP::EI, f16, f32, MP::EO), Alg>
139        }
140    } else {
141        launch_kernel::<R, MP, Alg>
142    };
143
144    launch(
145        client,
146        &input.as_ref(),
147        &weight.as_ref(),
148        bias,
149        out,
150        problem,
151        selection,
152    )
153}
154
/// Configures and launches a single convolution kernel for already-converted
/// input/weight handles: reshapes `out` to its implicit-GEMM `(M, N)` view,
/// negotiates line (vectorization) sizes, builds the algorithm config, and
/// launches the global convolution.
///
/// # Errors
/// Returns [`ConvLaunchError`] when no line size survives filtering
/// (`pick_max`) or when `Alg::setup` rejects the problem/selection.
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    weight: &TensorHandleRef<'_, R>,
    bias: &Option<TensorHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    selection: MatmulSelection,
) -> Result<(), ConvLaunchError>
where
    Input<Alg, MP>: ConvInputsLaunch,
    Output<Alg, MP>: ConcreteOutputFactory,
{
    let rank = out.shape.len();
    let dim_c = rank - 1;

    // Reshape out to (M, N): collapse batch + spatial dims into M, keep channels as N.
    let out_shape = [out.shape[0..dim_c].iter().product(), out.shape[dim_c]];
    // NOTE(review): the M stride is taken from the second-to-last dimension, which
    // only yields a valid 2-D view of a contiguous channels-last tensor — presumably
    // guaranteed by the caller; confirm.
    let out_strides = [out.strides[rank - 2], out.strides[dim_c]];

    // SAFETY: reuses `out`'s handle and element size unchanged; only the
    // shape/stride metadata is replaced with the 2-D view computed above.
    let out = unsafe {
        TensorHandleRef::from_raw_parts(out.handle, &out_strides, &out_shape, out.elem_size)
    };

    // Start from the line sizes available for the input/output element types,
    // then keep only those compatible with each tensor's strides and layout.
    let line_sizes = AvailableLineSizes::from_elem_types::<R>(
        &MP::EI::as_elem_native_unchecked(),
        &MP::EO::as_elem_native_unchecked(),
    )
    .filter_lhs_with_tensor(input.strides, input.shape, problem.lhs_layout)
    .filter_rhs_with_tensor(weight.strides, weight.shape, problem.rhs_layout)
    .filter_out_with_tensor(out.strides, out.shape);

    // Let the algorithm veto unsupported sizes, then take the largest remaining.
    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;

    let config = Alg::setup::<R, MP>(client, &problem, &selection, &line_sizes)?;

    // Re-read the line sizes from the config — presumably `setup` may adjust
    // them from the picked values; confirm.
    let line_sizes = config.line_sizes();

    let input = <Input<Alg, MP> as ConvInputsLaunch>::create(
        input,
        weight,
        &selection,
        &problem,
        &line_sizes,
    );
    let output = <Output<Alg, MP> as ConcreteOutputFactory>::create(
        &out,
        &selection,
        &problem.as_matmul_problem(),
        &line_sizes,
    );
    // Bias is vectorized with the output line size.
    let bias = bias.as_ref().map(|bias| bias.as_tensor_arg(line_sizes.out));

    // SAFETY: `launch_unchecked` skips runtime validation; the cube dim/count
    // come from the same selection/problem the config was built from.
    unsafe {
        Alg::GlobalConvolution::launch_unchecked::<(MP, Alg::Args), R>(
            client,
            config.cube_dim(),
            Alg::cube_count(&selection, &problem),
            input,
            bias,
            output,
            &problem,
            config,
        );
    }

    Ok(())
}