cubecl_convolution/
launch.rs

1use crate::{
2    components::{ConvGemmConfig as _, ConvSetupError},
3    kernels::layered::selector::launch_kernel_concrete,
4};
5use crate::{
6    components::{
7        ConvolutionProblem, Dimensionality,
8        global::args::{ConcreteInputsFactory, ConcreteOutputFactory},
9    },
10    kernels::layered::algorithm::Algorithm,
11};
12use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
13use cubecl_matmul::components::{
14    self, AvailableLineSizes, MatmulElems, MatmulIdent, MatmulSelection,
15};
16use cubecl_matmul::{
17    MatmulInputHandleRef,
18    components::{InputArg, OutputArg},
19};
20
/// Spatial convolution hyper-parameters, one entry per spatial dimension.
///
/// `N_SPATIAL` is the number of spatial dimensions of the convolution (e.g. `2` for a
/// 2D convolution over height and width).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Convolution stride for each spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Padding for each spatial dimension.
    pub padding: [usize; N_SPATIAL],
    /// Kernel dilation for each spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
27
28/// Perform an n-dimensional convolution using the implicit GEMM (im2col) algorithm, using cubecl
29/// tiling matmul components, using the specified algorithm.
30///
31/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
32/// * `weight` - The weights (filter) applied to each kernel, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
33/// * `out` - The output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
34/// * `bias` - The bias added to each out channel
35/// * `options` - The options to use for the convolution
36#[allow(clippy::result_large_err)]
37pub fn launch_conv<R: Runtime, Alg: Algorithm, const N_SPATIAL: usize>(
38    client: &ComputeClient<R::Server>,
39    input: &MatmulInputHandleRef<'_, R>,
40    weight: &MatmulInputHandleRef<'_, R>,
41    bias: &Option<TensorHandleRef<'_, R>>,
42    out: &TensorHandleRef<'_, R>,
43    args: ConvolutionArgs<N_SPATIAL>,
44    dtypes: MatmulElems,
45) -> Result<(), ConvSetupError>
46where
47    InputArg<Alg::Args>: ConcreteInputsFactory,
48    OutputArg<Alg::Args>: ConcreteOutputFactory,
49{
50    let ConvolutionArgs {
51        stride,
52        padding,
53        dilation,
54    } = args;
55
56    let dimensionality = match N_SPATIAL {
57        1 => Dimensionality::Dim1,
58        2 => Dimensionality::Dim2,
59        3 => Dimensionality::Dim3,
60        other => unimplemented!("Unsupported dimensionality {other}"),
61    };
62
63    launch::<R, Alg>(
64        client,
65        input,
66        weight,
67        bias,
68        out,
69        (&stride, &padding, &dilation),
70        dimensionality,
71        dtypes,
72    )
73}
74
75#[allow(clippy::too_many_arguments)]
76fn launch<R: Runtime, Alg: Algorithm>(
77    client: &ComputeClient<R::Server>,
78    input: &MatmulInputHandleRef<'_, R>,
79    weight: &MatmulInputHandleRef<'_, R>,
80    bias: &Option<TensorHandleRef<'_, R>>,
81    out: &TensorHandleRef<'_, R>,
82    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
83    dimensionality: Dimensionality,
84    mut dtypes: MatmulElems,
85) -> Result<(), ConvSetupError>
86where
87    InputArg<Alg::Args>: ConcreteInputsFactory,
88    OutputArg<Alg::Args>: ConcreteOutputFactory,
89{
90    let rank = input.data().shape.len();
91    let dim_c = rank - 1;
92
93    let n = input.data().shape[0];
94    let c = input.data().shape[dim_c];
95
96    let out_c = weight.data().shape[0];
97
98    let in_shape = &input.data().shape[1..dim_c];
99    let kernel_shape = &weight.data().shape[1..dim_c];
100    let out_shape = &out.shape[1..dim_c];
101
102    let input_data =
103        Alg::into_tensor_handle::<R>(client, input.data(), MatmulIdent::Lhs, dtypes.lhs_global);
104    let weight_data =
105        Alg::into_tensor_handle::<R>(client, weight.data(), MatmulIdent::Rhs, dtypes.rhs_global);
106
107    let mut input = *input;
108    let mut weight = *weight;
109
110    *input.data_mut() = input_data.as_ref();
111    *weight.data_mut() = weight_data.as_ref();
112
113    let plane_dim = client.properties().hardware.plane_size_max;
114
115    let problem = ConvolutionProblem {
116        m: n * out_shape.iter().product::<usize>(),
117        n: out_c,
118        k: c * kernel_shape.iter().product::<usize>(),
119        lhs_layout: components::MatrixLayout::RowMajor,
120        rhs_layout: components::MatrixLayout::ColMajor,
121        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
122        stride: stride.iter().map(|it| *it as u32).collect(),
123        padding: padding.iter().map(|it| *it as i32).collect(),
124        dilation: dilation.iter().map(|it| *it as u32).collect(),
125
126        batches: n,
127        shape: in_shape.to_vec(),
128        out_shape: out_shape.to_vec(),
129        channels: c,
130
131        dimensionality,
132    };
133
134    let selection = Alg::selection::<R>(client, &problem, plane_dim, &mut dtypes)?;
135
136    launch_kernel::<R, Alg>(
137        client, &input, &weight, bias, out, problem, selection, dtypes,
138    )
139}
140
141#[allow(clippy::result_large_err, clippy::too_many_arguments)]
142pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
143    client: &ComputeClient<R::Server>,
144    input: &MatmulInputHandleRef<'_, R>,
145    weight: &MatmulInputHandleRef<'_, R>,
146    bias: &Option<TensorHandleRef<'_, R>>,
147    out: &TensorHandleRef<'_, R>,
148    problem: ConvolutionProblem,
149    selection: MatmulSelection,
150    dtypes: MatmulElems,
151) -> Result<(), ConvSetupError>
152where
153    InputArg<Alg::Args>: ConcreteInputsFactory,
154    OutputArg<Alg::Args>: ConcreteOutputFactory,
155{
156    let line_sizes = AvailableLineSizes::from_type_sizes::<R>(
157        input.data().elem_size,
158        weight.data().elem_size,
159        out.elem_size,
160    )
161    .filter_lhs_with_tensor(input.data().strides, input.data().shape, problem.lhs_layout)
162    .filter_rhs_with_tensor(
163        weight.data().strides,
164        weight.data().shape,
165        problem.rhs_layout,
166    )
167    .filter_out_with_tensor(out.strides, out.shape);
168
169    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
170
171    let config = Alg::setup::<R>(client, &problem, &selection, &line_sizes, &dtypes)?;
172
173    let line_sizes = config.line_sizes();
174
175    launch_kernel_concrete::<R, Alg>(
176        client, input, weight, bias, out, problem, line_sizes, selection, &dtypes,
177    )
178}