burn_cubecl/kernel/conv/conv2d/
base.rs

1use burn_tensor::ops::{ConvOptions, ConvTransposeOptions};
2use cubecl::linalg::convolution::ConvLaunchError;
3
4use crate::{CubeRuntime, FloatElement, IntElement, tensor::CubeTensor};
5
6#[cfg(feature = "autotune")]
7use super::{conv_transpose2d_autotune, conv2d_autotune};
8use super::{
9    conv_transpose2d_col2im, conv_transpose2d_direct, conv2d_direct, conv2d_gemm_cyclic,
10    conv2d_im2col,
11};
12
/// The strategy to be used when launching a convolution kernel.
///
/// The set of variants depends on the `autotune` cargo feature: `Autotune` only exists when
/// that feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Conv2dStrategy {
    /// A simple direct convolution.
    Direct,
    /// Using autotune to choose the best kernel based on runtime information.
    #[cfg(feature = "autotune")]
    Autotune,
    /// GEMM (im2col) based implementation of convolution. Significantly increased memory usage.
    Gemm,
    /// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and
    /// has constraints on tensor shape.
    ImplicitGemm,
}
26
27impl Default for Conv2dStrategy {
28    fn default() -> Self {
29        // if autotune is enabled, default to autotune
30        #[cfg(feature = "autotune")]
31        return Conv2dStrategy::Autotune;
32
33        // if autotune is disabled, default to the more memory-conservative algorithm
34        #[cfg(not(feature = "autotune"))]
35        Conv2dStrategy::Direct
36    }
37}
38
/// The strategy to be used when launching a conv_transpose kernel.
///
/// The set of variants depends on the `autotune` cargo feature: `Autotune` only exists when
/// that feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvTranspose2dStrategy {
    /// A simple direct transposed convolution.
    Direct,
    /// Using autotune to choose the best kernel based on runtime information.
    #[cfg(feature = "autotune")]
    Autotune,
    /// GEMM (col2im) based implementation of transposed convolution. Significantly increased
    /// memory usage.
    Gemm,
}
49
50impl Default for ConvTranspose2dStrategy {
51    fn default() -> Self {
52        // if autotune is enabled, default to autotune
53        #[cfg(feature = "autotune")]
54        return ConvTranspose2dStrategy::Autotune;
55
56        // if autotune is disabled, default to the more memory-conservative algorithm
57        #[cfg(not(feature = "autotune"))]
58        ConvTranspose2dStrategy::Direct
59    }
60}
61
62/// Perform a 2D convolution with the given strategy
63///
64/// * `input` - The input feature map
65/// * `weight` - The weights (filter) applied to each kernel
66/// * `bias` - The bias added to each channel
67/// * `options` - The options to use for the convolution
68/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
69///
70pub fn conv2d<R: CubeRuntime, E: FloatElement>(
71    input: CubeTensor<R>,
72    weight: CubeTensor<R>,
73    bias: Option<CubeTensor<R>>,
74    options: ConvOptions<2>,
75    strategy: Conv2dStrategy,
76) -> Result<CubeTensor<R>, ConvLaunchError> {
77    match strategy {
78        Conv2dStrategy::Direct => conv2d_direct::<R, E>(input, weight, bias, options),
79        #[cfg(feature = "autotune")]
80        Conv2dStrategy::Autotune => Ok(conv2d_autotune::<R, E>(input, weight, bias, options)),
81        Conv2dStrategy::Gemm => conv2d_im2col::<R, E>(input, weight, bias, options),
82        Conv2dStrategy::ImplicitGemm => conv2d_gemm_cyclic::<R, E>(input, weight, bias, options),
83    }
84}
85
86/// Perform a 2D convolution with the given strategy
87///
88/// * `input` - The input feature map
89/// * `weight` - The weights (filter) applied to each kernel
90/// * `bias` - The bias added to each channel
91/// * `options` - The options to use for the convolution
92/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
93///
94pub fn conv_transpose2d<R: CubeRuntime, E: FloatElement, I: IntElement>(
95    input: CubeTensor<R>,
96    weight: CubeTensor<R>,
97    bias: Option<CubeTensor<R>>,
98    options: ConvTransposeOptions<2>,
99    strategy: ConvTranspose2dStrategy,
100) -> Result<CubeTensor<R>, ConvLaunchError> {
101    match strategy {
102        ConvTranspose2dStrategy::Direct => {
103            conv_transpose2d_direct::<R, E>(input, weight, bias, options)
104        }
105        #[cfg(feature = "autotune")]
106        ConvTranspose2dStrategy::Autotune => Ok(conv_transpose2d_autotune::<R, E>(
107            input, weight, bias, options,
108        )),
109        ConvTranspose2dStrategy::Gemm => {
110            conv_transpose2d_col2im::<R, E>(input, weight, bias, options)
111        }
112    }
113}