//! cubecl_convolution/launch.rs — entry points for launching implicit-GEMM (im2col) convolutions.

1use crate::components::ConvGemmConfig as _;
2use crate::{components::ConvSetupError, kernels::layered::selector::launch_kernel_concrete};
3use crate::{
4    components::{
5        ConvolutionProblem, Dimensionality,
6        global::args::{ConcreteInputsFactory, ConcreteOutputFactory},
7    },
8    kernels::layered::algorithm::Algorithm,
9};
10use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
11use cubecl_matmul::components::{self, AvailableLineSizes, MatmulElems, MatrixLayout};
12use cubecl_matmul::{
13    MatmulInputHandleRef,
14    components::{InputArg, OutputArg},
15};
16
/// Stride, padding and dilation settings for an `N_SPATIAL`-dimensional convolution.
#[derive(Clone)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Stride along each spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Zero-padding applied to each spatial dimension.
    pub padding: [usize; N_SPATIAL],
    /// Dilation factor along each spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
23
24/// Perform an n-dimensional convolution using the implicit GEMM (im2col) algorithm, using cubecl
25/// tiling matmul components, using the specified algorithm.
26///
27/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
28/// * `weight` - The weights (filter) applied to each kernel, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
29/// * `out` - The output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
30/// * `bias` - The bias added to each out channel
31/// * `options` - The options to use for the convolution
32#[allow(clippy::result_large_err)]
33pub fn launch_conv<R: Runtime, Alg: Algorithm, const N_SPATIAL: usize>(
34    client: &ComputeClient<R>,
35    input: &MatmulInputHandleRef<'_, R>,
36    weight: &MatmulInputHandleRef<'_, R>,
37    bias: &Option<TensorHandleRef<'_, R>>,
38    out: &TensorHandleRef<'_, R>,
39    args: ConvolutionArgs<N_SPATIAL>,
40    dtypes: MatmulElems,
41) -> Result<(), ConvSetupError>
42where
43    InputArg<Alg::Args>: ConcreteInputsFactory,
44    OutputArg<Alg::Args>: ConcreteOutputFactory,
45{
46    let ConvolutionArgs {
47        stride,
48        padding,
49        dilation,
50    } = args;
51
52    let dimensionality = match N_SPATIAL {
53        1 => Dimensionality::Dim1,
54        2 => Dimensionality::Dim2,
55        3 => Dimensionality::Dim3,
56        other => unimplemented!("Unsupported dimensionality {other}"),
57    };
58
59    launch::<R, Alg>(
60        client,
61        input,
62        weight,
63        bias,
64        out,
65        (&stride, &padding, &dilation),
66        dimensionality,
67        dtypes,
68    )
69}
70
71#[allow(clippy::too_many_arguments)]
72fn launch<R: Runtime, Alg: Algorithm>(
73    client: &ComputeClient<R>,
74    input: &MatmulInputHandleRef<'_, R>,
75    weight: &MatmulInputHandleRef<'_, R>,
76    bias: &Option<TensorHandleRef<'_, R>>,
77    out: &TensorHandleRef<'_, R>,
78    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
79    dimensionality: Dimensionality,
80    dtypes: MatmulElems,
81) -> Result<(), ConvSetupError>
82where
83    InputArg<Alg::Args>: ConcreteInputsFactory,
84    OutputArg<Alg::Args>: ConcreteOutputFactory,
85{
86    let rank = input.data().shape.len();
87    let dim_c = rank - 1;
88
89    let n = input.data().shape[0];
90    let c = input.data().shape[dim_c];
91
92    let out_c = weight.data().shape[0];
93
94    let in_shape = &input.data().shape[1..dim_c];
95    let kernel_shape = &weight.data().shape[1..dim_c];
96    let out_shape = &out.shape[1..dim_c];
97
98    let input_data = Alg::into_tensor_handle(client, input.data(), *dtypes.lhs_global)?;
99    let weight_data = Alg::into_tensor_handle(client, weight.data(), *dtypes.rhs_global)?;
100
101    let mut input = *input;
102    let mut weight = *weight;
103
104    *input.data_mut() = input_data.as_ref();
105    *weight.data_mut() = weight_data.as_ref();
106
107    let problem = ConvolutionProblem {
108        m: n * out_shape.iter().product::<usize>(),
109        n: out_c,
110        k: c * kernel_shape.iter().product::<usize>(),
111        lhs_strides: input.data().strides.to_vec(),
112        rhs_strides: weight.data().strides.to_vec(),
113        lhs_layout: components::MatrixLayout::RowMajor,
114        rhs_layout: components::MatrixLayout::ColMajor,
115        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
116        stride: stride.iter().map(|it| *it as u32).collect(),
117        padding: padding.iter().map(|it| *it as i32).collect(),
118        dilation: dilation.iter().map(|it| *it as u32).collect(),
119
120        batches: n,
121        shape: in_shape.to_vec(),
122        out_shape: out_shape.to_vec(),
123        channels: c,
124
125        dimensionality,
126    };
127
128    launch_kernel::<R, Alg>(client, &input, &weight, bias, out, problem, dtypes)
129}
130
131#[allow(clippy::result_large_err, clippy::too_many_arguments)]
132pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
133    client: &ComputeClient<R>,
134    input: &MatmulInputHandleRef<'_, R>,
135    weight: &MatmulInputHandleRef<'_, R>,
136    bias: &Option<TensorHandleRef<'_, R>>,
137    out: &TensorHandleRef<'_, R>,
138    problem: ConvolutionProblem,
139    mut dtypes: MatmulElems,
140) -> Result<(), ConvSetupError>
141where
142    InputArg<Alg::Args>: ConcreteInputsFactory,
143    OutputArg<Alg::Args>: ConcreteOutputFactory,
144{
145    let plane_dim = client.properties().hardware.plane_size_max;
146    // Shape/strides are treated as k-major, with the last dim always being the contiguous one.
147    // So for the sake of selecting a line size, the shape/strides are always row-major.
148    let line_sizes = AvailableLineSizes::from_type_sizes(
149        client,
150        input.data().elem_size,
151        weight.data().elem_size,
152        out.elem_size,
153    )
154    .filter_lhs_with_tensor(
155        input.data().strides,
156        input.data().shape,
157        MatrixLayout::RowMajor,
158    )
159    .filter_rhs_with_tensor(
160        weight.data().strides,
161        weight.data().shape,
162        MatrixLayout::RowMajor,
163    )
164    .filter_out_with_tensor(out.strides, out.shape);
165
166    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
167
168    let selection = Alg::selection(client, &problem, plane_dim, &line_sizes, &mut dtypes)?;
169
170    let config = Alg::setup(client, &problem, &selection, &line_sizes, &dtypes)?;
171
172    let line_sizes = config.line_sizes();
173
174    launch_kernel_concrete::<R, Alg>(
175        client, input, weight, bias, out, problem, line_sizes, selection, &dtypes,
176    )
177}