cubecl_linalg/convolution/launch.rs

use std::any::TypeId;

use cubecl_core::{Runtime, client::ComputeClient, prelude::*, tensor_line_size_parallel};
use half::f16;

use crate::matmul::{
    components::global::args::{ConcreteOutputFactory, MatmulArgs},
    kernels::MatmulLaunchError,
};
use crate::{
    convolution::base::ConvolutionLaunch,
    matmul::components::{self, InputIdent, MatmulPrecision, MatmulSelection},
};

use super::{
    ConvLaunchError,
    algorithm::{Algorithm, StageInput},
    args::ConvInputsLaunch,
    base::ConvolutionProblem,
    selection::select_matmul,
};

type Input<Alg, MP> = <<Alg as Algorithm>::Args as MatmulArgs>::Input<<MP as MatmulPrecision>::EI>;
type Output<Alg, MP> =
    <<Alg as Algorithm>::Args as MatmulArgs>::Output<<MP as MatmulPrecision>::EO>;

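/// Stride, padding, and dilation settings passed to [`launch_conv2d_nhwc`].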
#[derive(Clone)]
pub struct ConvolutionArgs {
    pub stride: (usize, usize),
    pub padding: (usize, usize),
    pub dilation: (usize, usize),
}

/// Performs a 2D convolution using the implicit GEMM (im2col) algorithm, built on the cubecl
/// tiling matmul components and specialized by the chosen [`Algorithm`].
///
/// * `input` - The input feature map; layout should be [batches, height, width, in_channels]
/// * `weight` - The convolution filters; layout should be [out_channels, kernel_h, kernel_w, in_channels]
/// * `out` - The output feature map; layout should be [batches, out_height, out_width, out_channels]
/// * `bias` - The bias added to each output channel
/// * `args` - The stride, padding and dilation to use for the convolution
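///
/// # Example
///
/// A minimal call sketch (not compiled; assumes a concrete `Runtime` `R`, a `MatmulPrecision`
/// `MP`, an [`Algorithm`] `Alg`, a compute `client`, and NHWC tensor handles `input`, `weight`,
/// `bias`, and `out` are already in scope):
///
/// ```ignore
/// let args = ConvolutionArgs {
///     stride: (1, 1),
///     padding: (1, 1),
///     dilation: (1, 1),
/// };
/// launch_conv2d_nhwc::<R, MP, Alg>(&client, &input, &weight, &Some(bias), &out, args)?;
/// ```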
#[allow(clippy::result_large_err)]
pub fn launch_conv2d_nhwc<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    weight: &TensorHandleRef<'_, R>,
    bias: &Option<TensorHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    args: ConvolutionArgs,
) -> Result<(), ConvLaunchError>
where
    Input<Alg, MP>: ConvInputsLaunch,
    Output<Alg, MP>: ConcreteOutputFactory,
{
    let ConvolutionArgs {
        stride,
        padding,
        dilation,
    } = args;

    let [n, h, w, c] = input.shape.try_into().unwrap();
    let [out_c, kh, kw, _] = weight.shape.try_into().unwrap();
    let out_h = out.shape[1];
    let out_w = out.shape[2];

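    // Hand both operands to the algorithm, which may repack them into the layout it expects
    // for the implicit GEMM (the input acts as Lhs, the weight as Rhs).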
    let input = Alg::into_tensor_handle::<R, MP::EI>(client, input, InputIdent::Lhs);
    let weight = Alg::into_tensor_handle::<R, MP::EI>(client, weight, InputIdent::Rhs);

    let ei_elem = MP::EI::as_elem_native_unchecked();
    let eo_elem = MP::EO::as_elem_native_unchecked();

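    // Pick vectorization (line) widths along axis 3, the innermost channel dimension of the
    // NHWC/OHWI layouts, based on the line sizes the runtime supports for each element type.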
72    let lhs_line_size =
73        tensor_line_size_parallel(R::line_size_elem(&ei_elem), &input.shape, &input.strides, 3);
74    let rhs_line_size = tensor_line_size_parallel(
75        R::line_size_elem(&ei_elem),
76        &weight.shape,
77        &weight.strides,
78        3,
79    );
80
81    let out_line_size =
82        tensor_line_size_parallel(R::line_size_elem(&eo_elem), out.shape, out.strides, 3);
83
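    // Plane (warp / subgroup) size reported by the hardware; fall back to 32 when the runtime
    // does not define one.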
    let plane_dim = client
        .properties()
        .hardware_properties()
        .defined_plane_size()
        .unwrap_or(32);

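    // Implicit GEMM (im2col) mapping: every output pixel becomes a row (m = batches * out_h *
    // out_w), every filter becomes a column (n = out_channels), and the reduction runs over the
    // unfolded receptive field (k = in_channels * kernel_h * kernel_w). For example, a
    // 1x8x8x16 input convolved with 32 filters of size 3x3 at stride 1 and padding 1 gives
    // m = 64, n = 32, k = 144.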
    let problem = ConvolutionProblem {
        m: n * out_h * out_w,
        n: out_c,
        k: c * kh * kw,
        lhs_layout: components::MatrixLayout::RowMajor,
        rhs_layout: components::MatrixLayout::ColMajor,
        lhs_line_size,
        rhs_line_size,
        out_line_size,
        kernel_size: (kh as u32, kw as u32),
        stride: (stride.0 as u32, stride.1 as u32),
        padding: (padding.0 as i32, padding.1 as i32),
        dilation: (dilation.0 as u32, dilation.1 as u32),

        batches: n,
        height: h,
        width: w,
        channels: c,

        out_h,
        out_w,
    };

    let (selection, config_input) = select_matmul::<Alg, R, MP>(client, &problem, plane_dim);

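    // f32 inputs are computed with a reduced-precision stage type (tf32 when the runtime
    // supports it, otherwise f16) while accumulating in f32, so accelerated matmul paths can be
    // used; every other precision launches with the caller-provided `MP` unchanged.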
    let launch = if TypeId::of::<MP::EI>() == TypeId::of::<f32>() {
        if tf32::is_supported(client) {
            launch_kernel::<R, (MP::EI, tf32, f32, MP::EO), Alg>
        } else {
            launch_kernel::<R, (MP::EI, f16, f32, MP::EO), Alg>
        }
    } else {
        launch_kernel::<R, MP, Alg>
    };

    launch(
        client,
        &input.as_ref(),
        &weight.as_ref(),
        bias,
        out,
        problem,
        selection,
        config_input,
    )
}

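/// Launches the convolution kernel with an already-selected matmul configuration.
///
/// This is the lower-level entry point used by [`launch_conv2d_nhwc`] once the operands have
/// been converted with `Alg::into_tensor_handle` and a matmul selection has been made.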
#[allow(clippy::result_large_err, clippy::too_many_arguments)]
pub fn launch_kernel<R: Runtime, MP: MatmulPrecision, Alg: Algorithm>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: &TensorHandleRef<'_, R>,
    weight: &TensorHandleRef<'_, R>,
    bias: &Option<TensorHandleRef<'_, R>>,
    out: &TensorHandleRef<'_, R>,
    problem: ConvolutionProblem,
    selection: MatmulSelection,
    config_input: StageInput,
) -> Result<(), ConvLaunchError>
where
    Input<Alg, MP>: ConvInputsLaunch,
    Output<Alg, MP>: ConcreteOutputFactory,
{
    // Reshape out to (M, N): collapse [batches, out_h, out_w] into the row dimension and keep
    // out_channels as the column dimension. Using strides[2] as the row stride relies on the
    // leading dimensions being laid out contiguously.
    let out_shape = [out.shape[0..3].iter().product(), out.shape[3]];
    let out_strides = [out.strides[2], out.strides[3]];

    let out = unsafe {
        TensorHandleRef::from_raw_parts(out.handle, &out_strides, &out_shape, out.elem_size)
    };

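    // Launch geometry and configuration: the algorithm derives the threads per cube and the
    // number of cubes needed to cover the problem from the matmul selection, then validates
    // them while assembling its config.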
    let cube_dim = Alg::cube_dim(&selection);
    let cube_count = Alg::cube_count(&selection, &problem);

    let config = Alg::make_config(config_input, &problem, &cube_dim, &cube_count)
        .map_err(MatmulLaunchError::InvalidConfig)?;

    Alg::check_availability::<R, MP>(client, &config)?;

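    // Assemble the launch arguments: the algorithm's input views over `input`/`weight`, the
    // output factory over the flattened `out` handle, and the optional bias vectorized with the
    // output line size.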
    let input = <Input<Alg, MP> as ConvInputsLaunch>::create(input, weight, &selection, &problem);
    let output = <Output<Alg, MP> as ConcreteOutputFactory>::create(
        &out,
        &selection,
        &problem.as_matmul_problem(),
    );
    let bias = bias
        .as_ref()
        .map(|bias| bias.as_tensor_arg(problem.out_line_size));

    unsafe {
        Alg::GlobalConvolution::launch_unchecked::<(MP, Alg::Args), R>(
            client, cube_dim, cube_count, input, bias, output, &problem, config,
        );
    }

    Ok(())
}