burn_cubecl/kernel/conv/conv_transpose2d/base.rs

1use crate::{CubeRuntime, tensor::CubeTensor};
2use burn_backend::ops::ConvTransposeOptions;
3use cubek::convolution::components::ConvSetupError;
4
5#[cfg(feature = "autotune")]
6use super::conv_transpose2d_autotune;
7use super::{conv_transpose2d_col2im, conv_transpose2d_direct};
8
/// The strategy to be used when launching a 2D transposed convolution kernel.
///
/// Passed to [`conv_transpose2d`] to select which implementation is dispatched.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvTranspose2dStrategy {
    /// A simple direct transposed convolution.
    Direct,
    /// Using autotune to choose the best kernel based on runtime information.
    ///
    /// Only available when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// GEMM (col2im) based implementation of transposed convolution.
    /// Significantly increased memory usage.
    Gemm,
}
19
20impl Default for ConvTranspose2dStrategy {
21    fn default() -> Self {
22        // if autotune is enabled, default to autotune
23        #[cfg(feature = "autotune")]
24        return ConvTranspose2dStrategy::Autotune;
25
26        // if autotune is disabled, default to the more memory-conservative algorithm
27        #[cfg(not(feature = "autotune"))]
28        ConvTranspose2dStrategy::Direct
29    }
30}
31
32/// Performs a 2D convolution with the given strategy
33///
34/// * `input` - The input feature map
35/// * `weight` - The weights (filter) applied to each kernel
36/// * `bias` - The bias added to each channel
37/// * `options` - The options to use for the convolution
38/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
39pub fn conv_transpose2d<R: CubeRuntime>(
40    input: CubeTensor<R>,
41    weight: CubeTensor<R>,
42    bias: Option<CubeTensor<R>>,
43    options: ConvTransposeOptions<2>,
44    strategy: ConvTranspose2dStrategy,
45) -> Result<CubeTensor<R>, ConvSetupError> {
46    match strategy {
47        ConvTranspose2dStrategy::Direct => conv_transpose2d_direct(input, weight, bias, options),
48        #[cfg(feature = "autotune")]
49        ConvTranspose2dStrategy::Autotune => {
50            Ok(conv_transpose2d_autotune(input, weight, bias, options))
51        }
52        ConvTranspose2dStrategy::Gemm => conv_transpose2d_col2im(input, weight, bias, options),
53    }
54}