use burn_backend::ops::ConvOptions;
use burn_std::Shape;
use cubek::convolution::{AcceleratedTileKind, components::ConvSetupError};
#[cfg(feature = "autotune")]
use crate::kernel::conv::{backward_weight::wgrad_autotune, dgrad_autotune};
use crate::{
CubeRuntime,
kernel::conv::{
backward_data::{fallback::conv_data_backward_fallback, implicit_gemm::*},
backward_weight::{fallback::conv_weight_backward_fallback, implicit_gemm::*},
forward::implicit_gemm::conv_gemm_simple_sync,
},
ops::{permute_nchw_to_nhwc, permute_nchw_to_nhwc_shape, permute_nhwc_to_nchw},
tensor::CubeTensor,
};
use super::conv_direct;
#[cfg(feature = "autotune")]
use super::forward::conv_autotune;
/// Selects which kernel family is used to execute a convolution
/// (forward or backward).
///
/// Plain fieldless strategy enum; derives `Debug`/`Clone`/`Copy` so it can be
/// logged and passed around freely (additive, backward-compatible).
#[derive(Debug, Clone, Copy)]
pub enum ConvStrategy {
    /// Direct (naive / fallback) convolution kernel.
    Direct,
    /// Pick the fastest applicable kernel via runtime autotuning.
    #[cfg(feature = "autotune")]
    Autotune,
    /// Implicit-GEMM convolution (tensor-core accelerated when available).
    ImplicitGemm,
}
impl Default for ConvStrategy {
    /// Autotuning is the preferred default whenever the feature is enabled.
    #[cfg(feature = "autotune")]
    fn default() -> Self {
        ConvStrategy::Autotune
    }

    /// Without autotuning available, fall back to the direct kernel.
    #[cfg(not(feature = "autotune"))]
    fn default() -> Self {
        ConvStrategy::Direct
    }
}
/// Computes the convolution forward pass on NCHW tensors.
///
/// Internally converts `input` and `weight` to channels-last (NHWC), runs the
/// NHWC kernel selected by `strategy`, and permutes the result back to NCHW.
///
/// # Errors
/// Propagates any [`ConvSetupError`] raised while setting up the chosen kernel.
pub fn conv_forward<R: CubeRuntime, const N: usize>(
    input: CubeTensor<R>,
    weight: CubeTensor<R>,
    bias: Option<CubeTensor<R>>,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let nhwc_out = conv_forward_nhwc(
        permute_nchw_to_nhwc(input),
        permute_nchw_to_nhwc(weight),
        bias,
        options,
        strategy,
    )?;
    Ok(permute_nhwc_to_nchw(nhwc_out))
}
/// Computes the convolution forward pass on tensors already in NHWC layout,
/// dispatching to the kernel selected by `strategy`.
///
/// # Errors
/// Propagates any [`ConvSetupError`] raised while setting up the chosen kernel.
pub fn conv_forward_nhwc<R: CubeRuntime, const N: usize>(
    input: CubeTensor<R>,
    weight: CubeTensor<R>,
    bias: Option<CubeTensor<R>>,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    match strategy {
        ConvStrategy::Direct => conv_direct::<R, N>(input, weight, bias, options),
        #[cfg(feature = "autotune")]
        ConvStrategy::Autotune => Ok(conv_autotune::<R, N>(input, weight, bias, options)),
        // The implicit-GEMM kernel does not handle grouped convolutions here;
        // route those to the direct kernel instead.
        ConvStrategy::ImplicitGemm if options.groups != 1 => {
            conv_direct::<R, N>(input, weight, bias, options)
        }
        ConvStrategy::ImplicitGemm => conv_gemm_simple_sync::<R, N>(
            input,
            weight,
            bias,
            options,
            AcceleratedTileKind::Cmma,
        ),
    }
}
/// Computes the gradient of the convolution with respect to the weights,
/// taking NCHW tensors and returning the weight gradient in NCHW layout.
///
/// `input` and `out_grad` are permuted to NHWC (and `weight_shape`
/// correspondingly) before dispatching on `strategy`; the result is permuted
/// back before returning.
///
/// # Errors
/// Propagates any [`ConvSetupError`] raised while setting up the chosen kernel.
pub fn conv_weight_backward<R: CubeRuntime, const N: usize>(
    input: CubeTensor<R>,
    out_grad: CubeTensor<R>,
    weight_shape: Shape,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let input = permute_nchw_to_nhwc(input);
    let out_grad = permute_nchw_to_nhwc(out_grad);
    let weight_shape = permute_nchw_to_nhwc_shape(weight_shape);

    let weight_grad = match strategy {
        ConvStrategy::Direct => {
            conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)?
        }
        #[cfg(feature = "autotune")]
        ConvStrategy::Autotune => wgrad_autotune::<R, N>(input, out_grad, weight_shape, options),
        // Implicit GEMM does not support grouped convolutions; use the fallback.
        ConvStrategy::ImplicitGemm if options.groups != 1 => {
            conv_weight_backward_fallback::<R, N>(input, out_grad, weight_shape, options)?
        }
        ConvStrategy::ImplicitGemm => wgrad_gemm_simple_sync::<R, N>(
            input,
            out_grad,
            weight_shape,
            options,
            AcceleratedTileKind::Cmma,
        )?,
    };
    Ok(permute_nhwc_to_nchw(weight_grad))
}
/// Computes the gradient of the convolution with respect to the input data,
/// taking NCHW tensors and returning the data gradient in NCHW layout.
///
/// `out_grad` and `weights` are permuted to NHWC (and `in_shape`
/// correspondingly) before dispatching on `strategy`; the result is permuted
/// back before returning.
///
/// # Errors
/// Propagates any [`ConvSetupError`] raised while setting up the chosen kernel.
pub fn conv_data_backward<R: CubeRuntime, const N: usize>(
    out_grad: CubeTensor<R>,
    weights: CubeTensor<R>,
    in_shape: Shape,
    options: ConvOptions<N>,
    strategy: ConvStrategy,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let out_grad = permute_nchw_to_nhwc(out_grad);
    let weights = permute_nchw_to_nhwc(weights);
    let in_shape = permute_nchw_to_nhwc_shape(in_shape);
    // Renamed from `weight_grad` (copy-paste from the wgrad path): this value
    // is the gradient with respect to the input data.
    let data_grad = match strategy {
        ConvStrategy::Direct => {
            conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?
        }
        #[cfg(feature = "autotune")]
        ConvStrategy::Autotune => dgrad_autotune::<R, N>(out_grad, weights, in_shape, options),
        ConvStrategy::ImplicitGemm => {
            // The implicit-GEMM dgrad kernel supports neither grouped nor
            // strided convolutions here; fall back for those cases.
            if options.groups != 1 || options.stride.iter().any(|&s| s != 1) {
                conv_data_backward_fallback::<R, N>(out_grad, weights, in_shape, options)?
            } else {
                dgrad_gemm_simple_sync::<R, N>(
                    out_grad,
                    weights,
                    in_shape,
                    options,
                    AcceleratedTileKind::Cmma,
                )?
            }
        }
    };
    Ok(permute_nhwc_to_nchw(data_grad))
}