burn_jit/kernel/conv/conv2d/
base.rs

1use burn_tensor::ops::{ConvOptions, ConvTransposeOptions};
2
3use crate::{
4    kernel::conv::ConvLaunchError, tensor::JitTensor, FloatElement, IntElement, JitRuntime,
5};
6
7#[cfg(feature = "autotune")]
8use super::{conv2d_autotune, conv_transpose2d_autotune};
9use super::{
10    conv2d_direct, conv2d_im2col, conv_transpose2d_col2im, conv_transpose2d_direct,
11    gemm::launch::conv2d_gemm_cmma_large_m, implicit_gemm::conv2d_implicit_gemm,
12};
13
/// The strategy to be used when launching a convolution kernel.
///
/// Every variant is a plain tag without payload, so the type derives the common value
/// traits and can be logged, copied and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Conv2dStrategy {
    /// A simple direct convolution.
    Direct,
    /// Using autotune to choose the best kernel based on runtime information.
    ///
    /// Only available when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// GEMM (im2col) based implementation of convolution. Significantly increased memory usage.
    Gemm,
    /// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and
    /// has constraints on tensor shape.
    ImplicitGemm,
    /// Implicit GEMM implementation of convolution. Uses `cubecl` matmul components to provide
    /// the flexibility needed to work well for varied problem sizes.
    ImplicitGemmComplex,
}
30
31impl Default for Conv2dStrategy {
32    fn default() -> Self {
33        // if autotune is enabled, default to autotune
34        #[cfg(feature = "autotune")]
35        return Conv2dStrategy::Autotune;
36
37        // if autotune is disabled, default to the more memory-conservative algorithm
38        #[cfg(not(feature = "autotune"))]
39        Conv2dStrategy::Direct
40    }
41}
42
/// The strategy to be used when launching a conv_transpose kernel.
///
/// Every variant is a plain tag without payload, so the type derives the common value
/// traits and can be logged, copied and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvTranspose2dStrategy {
    /// A simple direct transposed convolution.
    Direct,
    /// Using autotune to choose the best kernel based on runtime information.
    ///
    /// Only available when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// GEMM (col2im) based implementation of transposed convolution. Significantly increased
    /// memory usage.
    Gemm,
}
53
54impl Default for ConvTranspose2dStrategy {
55    fn default() -> Self {
56        // if autotune is enabled, default to autotune
57        #[cfg(feature = "autotune")]
58        return ConvTranspose2dStrategy::Autotune;
59
60        // if autotune is disabled, default to the more memory-conservative algorithm
61        #[cfg(not(feature = "autotune"))]
62        ConvTranspose2dStrategy::Direct
63    }
64}
65
66/// Perform a 2D convolution with the given strategy
67///
68/// * `input` - The input feature map
69/// * `weight` - The weights (filter) applied to each kernel
70/// * `bias` - The bias added to each channel
71/// * `options` - The options to use for the convolution
72/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
73///
74pub fn conv2d<R: JitRuntime, E: FloatElement>(
75    input: JitTensor<R>,
76    weight: JitTensor<R>,
77    bias: Option<JitTensor<R>>,
78    options: ConvOptions<2>,
79    strategy: Conv2dStrategy,
80) -> Result<JitTensor<R>, ConvLaunchError> {
81    match strategy {
82        Conv2dStrategy::Direct => conv2d_direct::<R, E>(input, weight, bias, options),
83        #[cfg(feature = "autotune")]
84        Conv2dStrategy::Autotune => Ok(conv2d_autotune::<R, E>(input, weight, bias, options)),
85        Conv2dStrategy::Gemm => conv2d_im2col::<R, E>(input, weight, bias, options),
86        Conv2dStrategy::ImplicitGemm => conv2d_implicit_gemm::<R, E>(input, weight, bias, options),
87        Conv2dStrategy::ImplicitGemmComplex => {
88            conv2d_gemm_cmma_large_m::<R, E>(input, weight, bias, options)
89        }
90    }
91}
92
93/// Perform a 2D convolution with the given strategy
94///
95/// * `input` - The input feature map
96/// * `weight` - The weights (filter) applied to each kernel
97/// * `bias` - The bias added to each channel
98/// * `options` - The options to use for the convolution
99/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
100///
101pub fn conv_transpose2d<R: JitRuntime, E: FloatElement, I: IntElement>(
102    input: JitTensor<R>,
103    weight: JitTensor<R>,
104    bias: Option<JitTensor<R>>,
105    options: ConvTransposeOptions<2>,
106    strategy: ConvTranspose2dStrategy,
107) -> Result<JitTensor<R>, ConvLaunchError> {
108    match strategy {
109        ConvTranspose2dStrategy::Direct => {
110            conv_transpose2d_direct::<R, E>(input, weight, bias, options)
111        }
112        #[cfg(feature = "autotune")]
113        ConvTranspose2dStrategy::Autotune => Ok(conv_transpose2d_autotune::<R, E>(
114            input, weight, bias, options,
115        )),
116        ConvTranspose2dStrategy::Gemm => {
117            conv_transpose2d_col2im::<R, E>(input, weight, bias, options)
118        }
119    }
120}