burn_cubecl/kernel/conv/conv_transpose2d/
base.rs1use crate::{CubeRuntime, tensor::CubeTensor};
2use burn_backend::ops::ConvTransposeOptions;
3use cubek::convolution::components::ConvSetupError;
4
5#[cfg(feature = "autotune")]
6use super::conv_transpose2d_autotune;
7use super::{conv_transpose2d_col2im, conv_transpose2d_direct};
8
/// Kernel selection strategy for the 2D transposed convolution entry point `conv_transpose2d`.
///
/// The type is a plain tag, so it is cheap to copy, compare and print (useful when
/// selecting or logging the strategy from configuration or benchmarks).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvTranspose2dStrategy {
    /// Run the direct kernel (`conv_transpose2d_direct`).
    Direct,
    /// Delegate kernel choice to `conv_transpose2d_autotune`.
    ///
    /// Only available when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// Run the col2im-based kernel (`conv_transpose2d_col2im`).
    Gemm,
}
19
20impl Default for ConvTranspose2dStrategy {
21 fn default() -> Self {
22 #[cfg(feature = "autotune")]
24 return ConvTranspose2dStrategy::Autotune;
25
26 #[cfg(not(feature = "autotune"))]
28 ConvTranspose2dStrategy::Direct
29 }
30}
31
32pub fn conv_transpose2d<R: CubeRuntime>(
40 input: CubeTensor<R>,
41 weight: CubeTensor<R>,
42 bias: Option<CubeTensor<R>>,
43 options: ConvTransposeOptions<2>,
44 strategy: ConvTranspose2dStrategy,
45) -> Result<CubeTensor<R>, ConvSetupError> {
46 match strategy {
47 ConvTranspose2dStrategy::Direct => conv_transpose2d_direct(input, weight, bias, options),
48 #[cfg(feature = "autotune")]
49 ConvTranspose2dStrategy::Autotune => {
50 Ok(conv_transpose2d_autotune(input, weight, bias, options))
51 }
52 ConvTranspose2dStrategy::Gemm => conv_transpose2d_col2im(input, weight, bias, options),
53 }
54}