cubecl_convolution/
launch.rs

1use crate::components::ConvGemmConfig as _;
2use crate::{components::ConvSetupError, kernels::layered::selector::launch_kernel_concrete};
3use crate::{
4    components::{
5        ConvolutionProblem, Dimensionality,
6        global::args::{ConcreteInputsFactory, ConcreteOutputFactory},
7    },
8    kernels::layered::algorithm::Algorithm,
9};
10use cubecl_core::{Runtime, client::ComputeClient, prelude::*};
11use cubecl_matmul::components::{
12    self, AvailableLineSizes, MatmulElems, MatmulIdent, MatmulSelection,
13};
14use cubecl_matmul::{
15    MatmulInputHandleRef,
16    components::{InputArg, OutputArg},
17};
18
/// Per-dimension parameters of an `N_SPATIAL`-dimensional convolution.
///
/// Every array holds exactly one entry per spatial dimension.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConvolutionArgs<const N_SPATIAL: usize> {
    /// Stride applied along each spatial dimension.
    pub stride: [usize; N_SPATIAL],
    /// Padding applied along each spatial dimension.
    pub padding: [usize; N_SPATIAL],
    /// Dilation applied along each spatial dimension.
    pub dilation: [usize; N_SPATIAL],
}
25
26/// Perform an n-dimensional convolution using the implicit GEMM (im2col) algorithm, using cubecl
27/// tiling matmul components, using the specified algorithm.
28///
29/// * `input` - The input feature map, layout should be [batches, depth, height, width, in_channels]
30/// * `weight` - The weights (filter) applied to each kernel, layout should be [out_channels, kernel_d, kernel_h, kernel_w, in_channels]
31/// * `out` - The output feature map, layout should be [batches, out_depth, out_height, out_width, out_channels]
32/// * `bias` - The bias added to each out channel
33/// * `options` - The options to use for the convolution
34#[allow(clippy::result_large_err)]
35pub fn launch_conv<R: Runtime, Alg: Algorithm, const N_SPATIAL: usize>(
36    client: &ComputeClient<R::Server>,
37    input: &MatmulInputHandleRef<'_, R>,
38    weight: &MatmulInputHandleRef<'_, R>,
39    bias: &Option<TensorHandleRef<'_, R>>,
40    out: &TensorHandleRef<'_, R>,
41    args: ConvolutionArgs<N_SPATIAL>,
42    dtypes: MatmulElems,
43) -> Result<(), ConvSetupError>
44where
45    InputArg<Alg::Args>: ConcreteInputsFactory,
46    OutputArg<Alg::Args>: ConcreteOutputFactory,
47{
48    let ConvolutionArgs {
49        stride,
50        padding,
51        dilation,
52    } = args;
53
54    let dimensionality = match N_SPATIAL {
55        1 => Dimensionality::Dim1,
56        2 => Dimensionality::Dim2,
57        3 => Dimensionality::Dim3,
58        other => unimplemented!("Unsupported dimensionality {other}"),
59    };
60
61    launch::<R, Alg>(
62        client,
63        input,
64        weight,
65        bias,
66        out,
67        (&stride, &padding, &dilation),
68        dimensionality,
69        dtypes,
70    )
71}
72
73#[allow(clippy::too_many_arguments)]
74fn launch<R: Runtime, Alg: Algorithm>(
75    client: &ComputeClient<R::Server>,
76    input: &MatmulInputHandleRef<'_, R>,
77    weight: &MatmulInputHandleRef<'_, R>,
78    bias: &Option<TensorHandleRef<'_, R>>,
79    out: &TensorHandleRef<'_, R>,
80    (stride, padding, dilation): (&[usize], &[usize], &[usize]),
81    dimensionality: Dimensionality,
82    mut dtypes: MatmulElems,
83) -> Result<(), ConvSetupError>
84where
85    InputArg<Alg::Args>: ConcreteInputsFactory,
86    OutputArg<Alg::Args>: ConcreteOutputFactory,
87{
88    let rank = input.data().shape.len();
89    let dim_c = rank - 1;
90
91    let n = input.data().shape[0];
92    let c = input.data().shape[dim_c];
93
94    let out_c = weight.data().shape[0];
95
96    let in_shape = &input.data().shape[1..dim_c];
97    let kernel_shape = &weight.data().shape[1..dim_c];
98    let out_shape = &out.shape[1..dim_c];
99
100    let input_data =
101        Alg::into_tensor_handle::<R>(client, input.data(), MatmulIdent::Lhs, dtypes.lhs_global);
102    let weight_data =
103        Alg::into_tensor_handle::<R>(client, weight.data(), MatmulIdent::Rhs, dtypes.rhs_global);
104
105    let mut input = *input;
106    let mut weight = *weight;
107
108    *input.data_mut() = input_data.as_ref();
109    *weight.data_mut() = weight_data.as_ref();
110
111    let plane_dim = client.properties().hardware.plane_size_max;
112
113    let problem = ConvolutionProblem {
114        m: n * out_shape.iter().product::<usize>(),
115        n: out_c,
116        k: c * kernel_shape.iter().product::<usize>(),
117        lhs_layout: components::MatrixLayout::RowMajor,
118        rhs_layout: components::MatrixLayout::ColMajor,
119        kernel_size: kernel_shape.iter().map(|it| *it as u32).collect(),
120        stride: stride.iter().map(|it| *it as u32).collect(),
121        padding: padding.iter().map(|it| *it as i32).collect(),
122        dilation: dilation.iter().map(|it| *it as u32).collect(),
123
124        batches: n,
125        shape: in_shape.to_vec(),
126        out_shape: out_shape.to_vec(),
127        channels: c,
128
129        dimensionality,
130    };
131
132    let selection = Alg::selection::<R>(client, &problem, plane_dim, &mut dtypes)?;
133
134    launch_kernel::<R, Alg>(
135        client, &input, &weight, bias, out, problem, selection, dtypes,
136    )
137}
138
139#[allow(clippy::result_large_err, clippy::too_many_arguments)]
140pub fn launch_kernel<R: Runtime, Alg: Algorithm>(
141    client: &ComputeClient<R::Server>,
142    input: &MatmulInputHandleRef<'_, R>,
143    weight: &MatmulInputHandleRef<'_, R>,
144    bias: &Option<TensorHandleRef<'_, R>>,
145    out: &TensorHandleRef<'_, R>,
146    problem: ConvolutionProblem,
147    selection: MatmulSelection,
148    dtypes: MatmulElems,
149) -> Result<(), ConvSetupError>
150where
151    InputArg<Alg::Args>: ConcreteInputsFactory,
152    OutputArg<Alg::Args>: ConcreteOutputFactory,
153{
154    let line_sizes = AvailableLineSizes::from_type_sizes::<R>(
155        input.data().elem_size,
156        weight.data().elem_size,
157        out.elem_size,
158    )
159    .filter_lhs_with_tensor(input.data().strides, input.data().shape, problem.lhs_layout)
160    .filter_rhs_with_tensor(
161        weight.data().strides,
162        weight.data().shape,
163        problem.rhs_layout,
164    )
165    .filter_out_with_tensor(out.strides, out.shape);
166
167    let line_sizes = Alg::filter_line_sizes(line_sizes).pick_max()?;
168
169    let config = Alg::setup::<R>(client, &problem, &selection, &line_sizes, &dtypes)?;
170
171    let line_sizes = config.line_sizes();
172
173    launch_kernel_concrete::<R, Alg>(
174        client, input, weight, bias, out, problem, line_sizes, selection, &dtypes,
175    )
176}