burn_jit/kernel/conv/conv2d/
base.rs1use burn_tensor::ops::{ConvOptions, ConvTransposeOptions};
2
3use crate::{
4 kernel::conv::ConvLaunchError, tensor::JitTensor, FloatElement, IntElement, JitRuntime,
5};
6
7#[cfg(feature = "autotune")]
8use super::{conv2d_autotune, conv_transpose2d_autotune};
9use super::{
10 conv2d_direct, conv2d_im2col, conv_transpose2d_col2im, conv_transpose2d_direct,
11 gemm::launch::conv2d_gemm_cmma_large_m, implicit_gemm::conv2d_implicit_gemm,
12};
13
/// Selects which kernel implementation `conv2d` dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Conv2dStrategy {
    /// Direct convolution kernel (`conv2d_direct`).
    Direct,
    /// Let autotuning pick the kernel at runtime (`conv2d_autotune`).
    /// Only available with the `autotune` feature.
    #[cfg(feature = "autotune")]
    Autotune,
    /// im2col-based GEMM convolution (`conv2d_im2col`).
    Gemm,
    /// Implicit GEMM convolution (`conv2d_implicit_gemm`).
    ImplicitGemm,
    /// Implicit GEMM convolution via the CMMA "large M" path (`conv2d_gemm_cmma_large_m`).
    ImplicitGemmComplex,
}

impl Default for Conv2dStrategy {
    /// Returns `Autotune` when the `autotune` feature is enabled, otherwise `Direct`.
    fn default() -> Self {
        // `Autotune` does not exist without the feature, so the choice must be made with
        // `#[cfg]` rather than a runtime `cfg!` check.
        #[cfg(feature = "autotune")]
        return Self::Autotune;

        #[cfg(not(feature = "autotune"))]
        Self::Direct
    }
}
42
/// Selects which kernel implementation `conv_transpose2d` dispatches to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvTranspose2dStrategy {
    /// Direct transposed-convolution kernel (`conv_transpose2d_direct`).
    Direct,
    /// Let autotuning pick the kernel at runtime (`conv_transpose2d_autotune`).
    /// Only available with the `autotune` feature.
    #[cfg(feature = "autotune")]
    Autotune,
    /// col2im-based GEMM transposed convolution (`conv_transpose2d_col2im`).
    Gemm,
}

impl Default for ConvTranspose2dStrategy {
    /// Returns `Autotune` when the `autotune` feature is enabled, otherwise `Direct`.
    fn default() -> Self {
        // `Autotune` does not exist without the feature, so the choice must be made with
        // `#[cfg]` rather than a runtime `cfg!` check.
        #[cfg(feature = "autotune")]
        return Self::Autotune;

        #[cfg(not(feature = "autotune"))]
        Self::Direct
    }
}
65
66pub fn conv2d<R: JitRuntime, E: FloatElement>(
75 input: JitTensor<R>,
76 weight: JitTensor<R>,
77 bias: Option<JitTensor<R>>,
78 options: ConvOptions<2>,
79 strategy: Conv2dStrategy,
80) -> Result<JitTensor<R>, ConvLaunchError> {
81 match strategy {
82 Conv2dStrategy::Direct => conv2d_direct::<R, E>(input, weight, bias, options),
83 #[cfg(feature = "autotune")]
84 Conv2dStrategy::Autotune => Ok(conv2d_autotune::<R, E>(input, weight, bias, options)),
85 Conv2dStrategy::Gemm => conv2d_im2col::<R, E>(input, weight, bias, options),
86 Conv2dStrategy::ImplicitGemm => conv2d_implicit_gemm::<R, E>(input, weight, bias, options),
87 Conv2dStrategy::ImplicitGemmComplex => {
88 conv2d_gemm_cmma_large_m::<R, E>(input, weight, bias, options)
89 }
90 }
91}
92
93pub fn conv_transpose2d<R: JitRuntime, E: FloatElement, I: IntElement>(
102 input: JitTensor<R>,
103 weight: JitTensor<R>,
104 bias: Option<JitTensor<R>>,
105 options: ConvTransposeOptions<2>,
106 strategy: ConvTranspose2dStrategy,
107) -> Result<JitTensor<R>, ConvLaunchError> {
108 match strategy {
109 ConvTranspose2dStrategy::Direct => {
110 conv_transpose2d_direct::<R, E>(input, weight, bias, options)
111 }
112 #[cfg(feature = "autotune")]
113 ConvTranspose2dStrategy::Autotune => Ok(conv_transpose2d_autotune::<R, E>(
114 input, weight, bias, options,
115 )),
116 ConvTranspose2dStrategy::Gemm => {
117 conv_transpose2d_col2im::<R, E>(input, weight, bias, options)
118 }
119 }
120}