use burn_backend::ops::ConvOptions;
use burn_std::Shape;
use cubek::{
convolution::{
AcceleratedTileKind, ConvAlgorithm, ConvolutionArgs, ConvolutionInputs, Strategy,
components::ConvSetupError, launch_ref,
},
matmul::definition::{MatmulElems, MatmulGlobalElems},
std::InputBinding,
};
use crate::{CubeRuntime, ops::numeric::empty_device_dtype, tensor::CubeTensor};
/// Computes the data (input) gradient of a convolution with the
/// synchronous simple GEMM algorithm family.
///
/// The concrete algorithm is picked from `tile_kind`: CMMA tiles use the
/// cyclic loader, MMA tiles the strided loader. All other arguments are
/// forwarded unchanged to [`launch_backwards_data`].
pub fn dgrad_gemm_simple_sync<R: CubeRuntime, const N: usize>(
    out_grad: CubeTensor<R>,
    weights: CubeTensor<R>,
    input_shape: Shape,
    options: ConvOptions<N>,
    tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
    // Only the algorithm differs per tile kind; build the strategy once.
    let algorithm = match tile_kind {
        AcceleratedTileKind::Cmma => ConvAlgorithm::SimpleSyncCyclic,
        AcceleratedTileKind::Mma => ConvAlgorithm::SimpleSyncStrided,
    };
    let strategy = Strategy::Inferred {
        algorithm,
        tile_kind,
    };
    launch_backwards_data::<R, N>(&strategy, out_grad, weights, input_shape, options)
}
/// Computes the data (input) gradient of a convolution with the
/// asynchronous simple GEMM algorithm family.
///
/// The concrete algorithm is picked from `tile_kind`: CMMA tiles use the
/// cyclic loader, MMA tiles the strided loader. All other arguments are
/// forwarded unchanged to [`launch_backwards_data`].
pub fn dgrad_gemm_simple_async<R: CubeRuntime, const N: usize>(
    out_grad: CubeTensor<R>,
    weights: CubeTensor<R>,
    input_shape: Shape,
    options: ConvOptions<N>,
    tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
    // Only the algorithm differs per tile kind; build the strategy once.
    let algorithm = match tile_kind {
        AcceleratedTileKind::Cmma => ConvAlgorithm::SimpleAsyncCyclic,
        AcceleratedTileKind::Mma => ConvAlgorithm::SimpleAsyncStrided,
    };
    let strategy = Strategy::Inferred {
        algorithm,
        tile_kind,
    };
    launch_backwards_data::<R, N>(&strategy, out_grad, weights, input_shape, options)
}
/// Computes the data (input) gradient of a convolution with the
/// TMA-based asynchronous simple GEMM algorithm.
///
/// Unlike the sync/async variants, the algorithm here is fixed to
/// `SimpleAsyncTma` regardless of `tile_kind`; the tile kind is still
/// recorded in the strategy. All arguments are forwarded to
/// [`launch_backwards_data`].
pub fn dgrad_gemm_simple_tma<R: CubeRuntime, const N: usize>(
    out_grad: CubeTensor<R>,
    weights: CubeTensor<R>,
    input_shape: Shape,
    options: ConvOptions<N>,
    tile_kind: AcceleratedTileKind,
) -> Result<CubeTensor<R>, ConvSetupError> {
    let strategy = Strategy::Inferred {
        algorithm: ConvAlgorithm::SimpleAsyncTma,
        tile_kind,
    };
    launch_backwards_data::<R, N>(&strategy, out_grad, weights, input_shape, options)
}
/// Launches the data-backward (input-gradient) convolution kernel for the
/// given `strategy`, writing the result into a freshly allocated tensor of
/// `input_shape` with the same dtype as `out_grad`.
///
/// # Errors
///
/// Returns a [`ConvSetupError`] when the convolution is grouped or strided
/// (unsupported by this path), or when the kernel launch itself fails.
pub fn launch_backwards_data<R: CubeRuntime, const N: usize>(
    strategy: &Strategy,
    out_grad: CubeTensor<R>,
    weights: CubeTensor<R>,
    input_shape: Shape,
    options: ConvOptions<N>,
) -> Result<CubeTensor<R>, ConvSetupError> {
    // Grouped and strided convolutions are rejected up front.
    // NOTE(review): a non-unit stride is also reported via the `Groups`
    // variant — confirm no dedicated stride error variant exists upstream.
    let unsupported = options.groups != 1 || options.stride.iter().any(|&s| s != 1);
    if unsupported {
        return Err(ConvSetupError::Groups(options.groups));
    }

    // Capture everything needed after `out_grad`/`weights` are consumed.
    let client = out_grad.client.clone();
    let out_dtype = out_grad.dtype;
    let weights_dtype = weights.dtype;

    // Destination buffer for the input gradient; shares out_grad's dtype.
    let in_grad = empty_device_dtype(
        out_grad.client.clone(),
        out_grad.device.clone(),
        input_shape,
        out_dtype,
    );

    // Element types for the underlying matmul: lhs = out_grad, rhs = weights.
    let dtypes = MatmulElems::from_globals(&MatmulGlobalElems {
        lhs: out_dtype.into(),
        rhs: weights_dtype.into(),
        out: out_dtype.into(),
    });

    launch_ref::<R, N>(
        strategy,
        &client,
        ConvolutionInputs::BackwardData {
            out_grad: InputBinding::new(out_grad.binding(), out_dtype.into()),
            weights: InputBinding::new(weights.binding(), weights_dtype.into()),
            in_grad: in_grad.clone().binding(),
        },
        ConvolutionArgs {
            stride: options.stride,
            padding: options.padding,
            dilation: options.dilation,
        },
        dtypes,
    )?;
    Ok(in_grad)
}