burn_cubecl/kernel/conv/conv2d/
base.rs1use burn_tensor::ops::{ConvOptions, ConvTransposeOptions};
2use cubecl::linalg::convolution::ConvLaunchError;
3
4use crate::{CubeRuntime, FloatElement, IntElement, tensor::CubeTensor};
5
6#[cfg(feature = "autotune")]
7use super::{conv_transpose2d_autotune, conv2d_autotune};
8use super::{
9 conv_transpose2d_col2im, conv_transpose2d_direct, conv2d_direct, conv2d_gemm_cyclic,
10 conv2d_im2col,
11};
12
/// Selects which kernel [`conv2d`] uses to compute a 2D convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Conv2dStrategy {
    /// Direct convolution; dispatched to `conv2d_direct`.
    Direct,
    /// Delegates the kernel choice to `conv2d_autotune`.
    /// Only available when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// GEMM-based convolution; dispatched to `conv2d_im2col`.
    Gemm,
    /// Implicit-GEMM convolution; dispatched to `conv2d_gemm_cyclic`.
    ImplicitGemm,
}

impl Default for Conv2dStrategy {
    /// Returns `Autotune` when the `autotune` feature is enabled, and `Direct` otherwise.
    fn default() -> Self {
        // Exactly one of the two cfg branches survives compilation, so only one of these
        // statements is ever present in the function body.
        #[cfg(feature = "autotune")]
        return Self::Autotune;

        #[cfg(not(feature = "autotune"))]
        Self::Direct
    }
}
38
/// Selects which kernel [`conv_transpose2d`] uses to compute a 2D transposed convolution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvTranspose2dStrategy {
    /// Direct transposed convolution; dispatched to `conv_transpose2d_direct`.
    Direct,
    /// Delegates the kernel choice to `conv_transpose2d_autotune`.
    /// Only available when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// GEMM-based transposed convolution; dispatched to `conv_transpose2d_col2im`.
    Gemm,
}

impl Default for ConvTranspose2dStrategy {
    /// Returns `Autotune` when the `autotune` feature is enabled, and `Direct` otherwise.
    fn default() -> Self {
        // Exactly one of the two cfg branches survives compilation, so only one of these
        // statements is ever present in the function body.
        #[cfg(feature = "autotune")]
        return Self::Autotune;

        #[cfg(not(feature = "autotune"))]
        Self::Direct
    }
}
61
62pub fn conv2d<R: CubeRuntime, E: FloatElement>(
71 input: CubeTensor<R>,
72 weight: CubeTensor<R>,
73 bias: Option<CubeTensor<R>>,
74 options: ConvOptions<2>,
75 strategy: Conv2dStrategy,
76) -> Result<CubeTensor<R>, ConvLaunchError> {
77 match strategy {
78 Conv2dStrategy::Direct => conv2d_direct::<R, E>(input, weight, bias, options),
79 #[cfg(feature = "autotune")]
80 Conv2dStrategy::Autotune => Ok(conv2d_autotune::<R, E>(input, weight, bias, options)),
81 Conv2dStrategy::Gemm => conv2d_im2col::<R, E>(input, weight, bias, options),
82 Conv2dStrategy::ImplicitGemm => conv2d_gemm_cyclic::<R, E>(input, weight, bias, options),
83 }
84}
85
86pub fn conv_transpose2d<R: CubeRuntime, E: FloatElement, I: IntElement>(
95 input: CubeTensor<R>,
96 weight: CubeTensor<R>,
97 bias: Option<CubeTensor<R>>,
98 options: ConvTransposeOptions<2>,
99 strategy: ConvTranspose2dStrategy,
100) -> Result<CubeTensor<R>, ConvLaunchError> {
101 match strategy {
102 ConvTranspose2dStrategy::Direct => {
103 conv_transpose2d_direct::<R, E>(input, weight, bias, options)
104 }
105 #[cfg(feature = "autotune")]
106 ConvTranspose2dStrategy::Autotune => Ok(conv_transpose2d_autotune::<R, E>(
107 input, weight, bias, options,
108 )),
109 ConvTranspose2dStrategy::Gemm => {
110 conv_transpose2d_col2im::<R, E>(input, weight, bias, options)
111 }
112 }
113}