use crate::{CubeRuntime, tensor::CubeTensor};
use burn_backend::ops::ConvTransposeOptions;
use cubek::convolution::components::ConvSetupError;
#[cfg(feature = "autotune")]
use super::conv_transpose2d_autotune;
use super::{conv_transpose2d_col2im, conv_transpose2d_direct};
/// Selects which implementation `conv_transpose2d` dispatches to.
///
/// The default (see the `Default` impl) is `Autotune` when the `autotune`
/// feature is enabled and `Direct` otherwise.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvTranspose2dStrategy {
    /// Dispatches to `conv_transpose2d_direct`.
    Direct,
    /// Dispatches to `conv_transpose2d_autotune`.
    ///
    /// Only compiled in when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// GEMM-based path; dispatches to `conv_transpose2d_col2im`.
    Gemm,
}
impl Default for ConvTranspose2dStrategy {
    /// Returns the strategy used when the caller does not pick one explicitly.
    ///
    /// With the `autotune` feature enabled this is `ConvTranspose2dStrategy::Autotune`;
    /// without it, the `Autotune` variant does not exist and `ConvTranspose2dStrategy::Direct`
    /// is used instead.
    fn default() -> Self {
        // Exactly one of these two items survives cfg-stripping, so `FALLBACK`
        // always resolves to a single definition.
        #[cfg(feature = "autotune")]
        const FALLBACK: ConvTranspose2dStrategy = ConvTranspose2dStrategy::Autotune;
        #[cfg(not(feature = "autotune"))]
        const FALLBACK: ConvTranspose2dStrategy = ConvTranspose2dStrategy::Direct;

        FALLBACK
    }
}
/// Runs a 2D transposed convolution using the implementation selected by `strategy`.
///
/// Dispatch table:
/// * `Direct` -> `conv_transpose2d_direct`
/// * `Gemm` -> `conv_transpose2d_col2im`
/// * `Autotune` (only with the `autotune` feature) -> `conv_transpose2d_autotune`
///
/// `input`, `weight`, the optional `bias` and `options` are forwarded unchanged
/// to the selected implementation.
///
/// # Errors
///
/// Propagates the [`ConvSetupError`] returned by the `Direct` or `Gemm` path.
/// The `Autotune` path produces a tensor directly and is always wrapped in `Ok`.
pub fn conv_transpose2d<R: CubeRuntime>(
    input: CubeTensor<R>,
    weight: CubeTensor<R>,
    bias: Option<CubeTensor<R>>,
    options: ConvTransposeOptions<2>,
    strategy: ConvTranspose2dStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let output = match strategy {
        ConvTranspose2dStrategy::Gemm => conv_transpose2d_col2im(input, weight, bias, options)?,
        ConvTranspose2dStrategy::Direct => conv_transpose2d_direct(input, weight, bias, options)?,
        #[cfg(feature = "autotune")]
        ConvTranspose2dStrategy::Autotune => {
            conv_transpose2d_autotune(input, weight, bias, options)
        }
    };

    Ok(output)
}