use super::{conv, pool};
use crate::ops::attention;
use crate::ops::unfold::unfold4d_using_conv2d;
use crate::tensor::{BoolTensor, FloatTensor, IntTensor};
use crate::{Backend, ElementConversion, TensorMetadata};
use burn_std::Shape;
use core::num::NonZeroUsize;
/// Gradients produced by the 2D convolution backward pass.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,
    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
/// Gradients produced by the deformable 2D convolution backward pass.
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
    /// Gradient with respect to the sampling offsets.
    pub offset_grad: FloatTensor<B>,
    /// Gradient with respect to the convolution weights.
    pub weight_grad: FloatTensor<B>,
    /// Gradient with respect to the modulation mask, when a mask was used.
    pub mask_grad: Option<FloatTensor<B>>,
    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
/// Gradients produced by the 3D convolution backward pass.
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
    /// Gradient with respect to the convolution weights.
    pub weights_grad: FloatTensor<B>,
    /// Gradient with respect to the bias, when a bias was used.
    pub bias_grad: Option<FloatTensor<B>>,
}
/// Gradient produced by the 1D max pooling backward pass.
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
/// Result of 1D max pooling when the argmax indices are also requested.
#[derive(new)]
pub struct MaxPool1dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,
    /// Indices of the selected maxima, consumed by the backward pass.
    pub indices: IntTensor<B>,
}
/// Gradient produced by the 2D max pooling backward pass.
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
/// Result of 2D max pooling when the argmax indices are also requested.
#[derive(new)]
pub struct MaxPool2dWithIndices<B: Backend> {
    /// The pooled output values.
    pub output: FloatTensor<B>,
    /// Indices of the selected maxima, consumed by the backward pass.
    pub indices: IntTensor<B>,
}
/// Validates that `value` is non-zero and returns it unchanged, so the call
/// can be used inline when building option structs.
///
/// # Panics
/// Panics with `msg` when `value` is zero.
pub(crate) fn check_nonzero(value: usize, msg: &str) -> usize {
    match NonZeroUsize::new(value) {
        Some(_) => value,
        None => panic!("{}", msg),
    }
}
/// Convolution options, generic over the number of spatial dimensions `N`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by `new`).
    pub stride: [usize; N],
    /// Zero padding per spatial dimension (zero is allowed).
    pub padding: [usize; N],
    /// Dilation per spatial dimension (validated non-zero by `new`).
    pub dilation: [usize; N],
    /// Number of groups (validated non-zero by `new`).
    pub groups: usize,
}
impl<const N: usize> ConvOptions<N> {
    /// Builds convolution options, validating that every stride and dilation
    /// entry and the group count are non-zero. Padding may be zero.
    ///
    /// # Panics
    /// Panics if any stride or dilation entry, or `groups`, is zero.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        let stride = stride.map(|s| check_nonzero(s, "stride must be non-zero"));
        let dilation = dilation.map(|d| check_nonzero(d, "dilation must be non-zero"));
        let groups = check_nonzero(groups, "groups must be non-zero");
        Self {
            stride,
            padding,
            dilation,
            groups,
        }
    }
}
/// Deformable convolution options, generic over the number of spatial
/// dimensions `N`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by `new`).
    pub stride: [usize; N],
    /// Zero padding per spatial dimension (zero is allowed).
    pub padding: [usize; N],
    /// Dilation per spatial dimension (validated non-zero by `new`).
    pub dilation: [usize; N],
    /// Number of weight groups (validated non-zero by `new`).
    pub weight_groups: usize,
    /// Number of offset groups (validated non-zero by `new`).
    pub offset_groups: usize,
}
impl<const N: usize> DeformConvOptions<N> {
    /// Builds deformable convolution options, validating that every stride
    /// and dilation entry and both group counts are non-zero. Padding may
    /// be zero.
    ///
    /// # Panics
    /// Panics if any stride/dilation entry, `weight_groups`, or
    /// `offset_groups` is zero.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        dilation: [usize; N],
        weight_groups: usize,
        offset_groups: usize,
    ) -> Self {
        let stride = stride.map(|s| check_nonzero(s, "stride must be non-zero"));
        let dilation = dilation.map(|d| check_nonzero(d, "dilation must be non-zero"));
        let weight_groups = check_nonzero(weight_groups, "weight groups must be non-zero");
        let offset_groups = check_nonzero(offset_groups, "offset groups must be non-zero");
        Self {
            stride,
            padding,
            dilation,
            weight_groups,
            offset_groups,
        }
    }
}
/// Transposed convolution options, generic over the number of spatial
/// dimensions `N`.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride per spatial dimension (validated non-zero by `new`).
    pub stride: [usize; N],
    /// Zero padding per spatial dimension (zero is allowed).
    pub padding: [usize; N],
    /// Additional output padding per spatial dimension (zero is allowed).
    pub padding_out: [usize; N],
    /// Dilation per spatial dimension (validated non-zero by `new`).
    pub dilation: [usize; N],
    /// Number of groups (validated non-zero by `new`).
    pub groups: usize,
}
impl<const N: usize> ConvTransposeOptions<N> {
    /// Builds transposed convolution options, validating that every stride
    /// and dilation entry and the group count are non-zero. Both padding
    /// arrays may contain zeros.
    ///
    /// # Panics
    /// Panics if any stride or dilation entry, or `groups`, is zero.
    pub fn new(
        stride: [usize; N],
        padding: [usize; N],
        padding_out: [usize; N],
        dilation: [usize; N],
        groups: usize,
    ) -> Self {
        let stride = stride.map(|s| check_nonzero(s, "stride must be non-zero"));
        let dilation = dilation.map(|d| check_nonzero(d, "dilation must be non-zero"));
        let groups = check_nonzero(groups, "groups must be non-zero");
        Self {
            stride,
            padding,
            padding_out,
            dilation,
            groups,
        }
    }
}
/// Options for the 4D unfold (im2col) operation.
///
/// Derives `Hash`/`PartialEq`/`Eq` for consistency with the other option
/// structs in this module (`ConvOptions`, `ConvTransposeOptions`,
/// `DeformConvOptions`), which all support equality and hashing.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct UnfoldOptions {
    /// Stride per spatial dimension (validated non-zero by `new`).
    pub stride: [usize; 2],
    /// Zero padding per spatial dimension (zero is allowed).
    pub padding: [usize; 2],
    /// Dilation per spatial dimension (validated non-zero by `new`).
    pub dilation: [usize; 2],
}
impl UnfoldOptions {
    /// Builds unfold options, validating that every stride and dilation
    /// entry is non-zero. Padding may be zero.
    ///
    /// # Panics
    /// Panics if any stride or dilation entry is zero.
    pub fn new(stride: [usize; 2], padding: [usize; 2], dilation: [usize; 2]) -> Self {
        let stride = stride.map(|s| check_nonzero(s, "stride must be non-zero"));
        let dilation = dilation.map(|d| check_nonzero(d, "dilation must be non-zero"));
        Self {
            stride,
            padding,
            dilation,
        }
    }
}
/// Algorithm used when resampling a tensor spatially.
#[derive(new, Debug, Clone, serde::Deserialize, serde::Serialize)]
pub enum InterpolateMode {
    /// Nearest-neighbor interpolation.
    Nearest,
    /// Bilinear interpolation.
    Bilinear,
    /// Bicubic interpolation.
    Bicubic,
}
/// Options for the interpolate operation.
#[derive(new, Debug, Clone)]
pub struct InterpolateOptions {
    /// Interpolation algorithm to apply.
    pub mode: InterpolateMode,
}
/// How grid-sample coordinates outside the input are handled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Deserialize, serde::Serialize)]
pub enum GridSamplePaddingMode {
    /// Out-of-bounds locations read as zero (the default).
    #[default]
    Zeros,
    /// Out-of-bounds locations clamp to the border values.
    Border,
    /// Out-of-bounds locations reflect back into the input.
    Reflection,
}
/// Options for the grid-sample operation.
#[derive(Debug, Clone)]
pub struct GridSampleOptions {
    /// Interpolation algorithm used between sampled points.
    pub mode: InterpolateMode,
    /// Handling of coordinates that fall outside the input.
    pub padding_mode: GridSamplePaddingMode,
    /// Whether extreme coordinates map to the centers of the corner pixels
    /// (`true`) or to the input edges (`false`).
    pub align_corners: bool,
}
impl Default for GridSampleOptions {
fn default() -> Self {
Self {
mode: InterpolateMode::Bilinear,
padding_mode: GridSamplePaddingMode::Zeros,
align_corners: false,
}
}
}
impl From<InterpolateMode> for GridSampleOptions {
fn from(value: InterpolateMode) -> Self {
GridSampleOptions::new(value)
}
}
impl GridSampleOptions {
    /// Creates options with the given interpolation mode and default
    /// padding mode / corner alignment.
    pub fn new(mode: InterpolateMode) -> Self {
        Self {
            mode,
            ..Self::default()
        }
    }
    /// Returns the options with the padding mode replaced.
    pub fn with_padding_mode(self, padding_mode: GridSamplePaddingMode) -> Self {
        Self {
            padding_mode,
            ..self
        }
    }
    /// Returns the options with the corner-alignment flag replaced.
    pub fn with_align_corners(self, align_corners: bool) -> Self {
        Self {
            align_corners,
            ..self
        }
    }
}
/// Padding strategy for the pad operation.
#[derive(Debug, Clone, Copy, PartialEq, serde::Deserialize, serde::Serialize)]
pub enum PadMode {
    /// Pad with the given constant value.
    Constant(f32),
    /// Pad by reflecting the input about its edges.
    Reflect,
    /// Pad by replicating the edge values.
    Edge,
}
impl Default for PadMode {
fn default() -> Self {
PadMode::Constant(0.0)
}
}
impl<E: ElementConversion> From<E> for PadMode {
    /// Wraps any element-convertible value as a constant padding mode.
    fn from(value: E) -> Self {
        Self::Constant(value.elem())
    }
}
/// Gradient produced by the interpolate backward pass.
#[derive(new)]
pub struct InterpolateBackward<B: Backend> {
    /// Gradient with respect to the input `x`.
    pub x_grad: FloatTensor<B>,
}
/// Module operations a [`Backend`] must provide: embedding, convolutions
/// (forward and backward), pooling, unfold, interpolation and attention.
///
/// Methods with default bodies are implemented in terms of other backend
/// ops or the `conv`/`pool`/`attention` helper modules; backends only need
/// to supply the required methods, and may override the defaults with
/// faster specialized implementations.
pub trait ModuleOps<B: Backend> {
    /// Embedding lookup: gathers rows of `weights` selected by `indices`.
    ///
    /// `weights` is `[n_embeddings, d_model]`, `indices` is
    /// `[batch_size, seq_length]`; returns `[batch_size, seq_length, d_model]`.
    fn embedding(weights: FloatTensor<B>, indices: IntTensor<B>) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [_, d_model] = weights.shape().dims();
        // Flatten the indices so a single select along dim 0 gathers all rows.
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output = B::float_select(weights, 0, indices);
        B::float_reshape(output, Shape::new([batch_size, seq_length, d_model]))
    }
    /// Backward pass of [`ModuleOps::embedding`]: scatter-adds the rows of
    /// `output_grad` back into a zero tensor shaped like `weights`.
    fn embedding_backward(
        weights: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> FloatTensor<B> {
        let [batch_size, seq_length] = indices.shape().dims();
        let [n_embeddings, d_model] = weights.shape().dims();
        let device = B::float_device(&weights);
        // The gradient keeps the dtype of the incoming gradient.
        let dtype = output_grad.dtype();
        let indices = B::int_reshape(indices, Shape::new([batch_size * seq_length]));
        let output_grad =
            B::float_reshape(output_grad, Shape::new([batch_size * seq_length, d_model]));
        let grad = B::float_zeros(Shape::new([n_embeddings, d_model]), &device, dtype.into());
        // select_add accumulates, so rows referenced multiple times sum
        // their gradients.
        B::float_select_add(grad, 0, indices, output_grad)
    }
    /// One-dimensional convolution.
    ///
    /// Default: implemented on top of [`ModuleOps::conv2d`].
    fn conv1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to the input `x`.
    fn conv1d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to the weights.
    fn conv1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv1d`] with respect to the bias.
    fn conv1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Two-dimensional convolution. Required: backends must provide this.
    fn conv2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv2d`] with respect to the input `x`.
    fn conv2d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv2d`] with respect to the weights.
    fn conv2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv2d`] with respect to the bias.
    fn conv2d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
    }
    /// Deformable two-dimensional convolution. Required.
    fn deform_conv2d(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::deform_conv2d`], returning gradients
    /// for all inputs at once. Required.
    fn deform_conv2d_backward(
        x: FloatTensor<B>,
        offset: FloatTensor<B>,
        weight: FloatTensor<B>,
        mask: Option<FloatTensor<B>>,
        bias: Option<FloatTensor<B>>,
        output_grad: FloatTensor<B>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<B>;
    /// Three-dimensional convolution. Required.
    fn conv3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv3d`] with respect to the input `x`.
    fn conv3d_x_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_x_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv3d`] with respect to the weights.
    fn conv3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv3d`] with respect to the bias.
    fn conv3d_bias_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv3d_bias_backward::<B>(x, weight, bias, output_grad)
    }
    /// One-dimensional transposed convolution.
    ///
    /// Default: implemented on top of [`ModuleOps::conv_transpose2d`].
    fn conv_transpose1d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] with respect to the input.
    fn conv_transpose1d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] with respect to the
    /// weights.
    fn conv_transpose1d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<1>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose1d`] with respect to the bias.
    fn conv_transpose1d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose1d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Two-dimensional transposed convolution. Required.
    fn conv_transpose2d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv_transpose2d`] with respect to the input.
    fn conv_transpose2d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose2d`] with respect to the
    /// weights.
    fn conv_transpose2d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<2>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose2d`] with respect to the bias.
    fn conv_transpose2d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose2d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Three-dimensional transposed convolution. Required.
    fn conv_transpose3d(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        bias: Option<FloatTensor<B>>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B>;
    /// Gradient of [`ModuleOps::conv_transpose3d`] with respect to the input.
    fn conv_transpose3d_x_backward(
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_x_backward::<B>(weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose3d`] with respect to the
    /// weights.
    fn conv_transpose3d_weight_backward(
        x: FloatTensor<B>,
        weight: FloatTensor<B>,
        output_grad: FloatTensor<B>,
        options: ConvTransposeOptions<3>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_weight_backward::<B>(x, weight, output_grad, options)
    }
    /// Gradient of [`ModuleOps::conv_transpose3d`] with respect to the bias.
    fn conv_transpose3d_bias_backward(
        x: FloatTensor<B>,
        bias: FloatTensor<B>,
        output_grad: FloatTensor<B>,
    ) -> FloatTensor<B> {
        conv::conv_transpose3d_bias_backward::<B>(x, bias, output_grad)
    }
    /// Four-dimensional unfold (im2col): extracts sliding local blocks from
    /// `x` and flattens them into columns.
    ///
    /// Default: a fast path using `float_unfold` when there is no padding
    /// and unit dilation, otherwise falls back to a conv2d-based
    /// implementation.
    fn unfold4d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        options: UnfoldOptions,
    ) -> FloatTensor<B> {
        if options.padding == [0, 0] && options.dilation == [1, 1] {
            // Unfold each spatial dimension (dims 2 and 3) in turn, giving
            // [batch, channels, h_out, w_out, k_h, k_w].
            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);
            // Bring the kernel axes next to the channel axis:
            // [batch, channels, k_h, k_w, h_out, w_out].
            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
            let shape = &blocks.shape().dims;
            // Collapse to [batch, channels * k_h * k_w, h_out * w_out].
            B::float_reshape(
                blocks,
                [
                    shape[0],
                    shape[1] * shape[2] * shape[3],
                    shape[4] * shape[5],
                ]
                .into(),
            )
        } else {
            unfold4d_using_conv2d::<B>(x, kernel_size, options)
        }
    }
    /// One-dimensional average pooling.
    ///
    /// Default: implemented on top of the 2D version.
    fn avg_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Backward pass of [`ModuleOps::avg_pool1d`].
    fn avg_pool1d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::avg_pool1d_backward_from_2d::<B>(
            x,
            grad,
            kernel_size,
            stride,
            padding,
            count_include_pad,
            ceil_mode,
        )
    }
    /// Two-dimensional average pooling. Required.
    fn avg_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::avg_pool2d`]. Required.
    fn avg_pool2d_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        count_include_pad: bool,
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Two-dimensional adaptive average pooling to `output_size`. Required.
    fn adaptive_avg_pool2d(x: FloatTensor<B>, output_size: [usize; 2]) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::adaptive_avg_pool2d`]. Required.
    fn adaptive_avg_pool2d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B>;
    /// One-dimensional adaptive average pooling.
    ///
    /// Default: implemented on top of the 2D version.
    fn adaptive_avg_pool1d(x: FloatTensor<B>, output_size: usize) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_from_2d::<B>(x, output_size)
    }
    /// Backward pass of [`ModuleOps::adaptive_avg_pool1d`].
    fn adaptive_avg_pool1d_backward(x: FloatTensor<B>, grad: FloatTensor<B>) -> FloatTensor<B> {
        pool::adaptive_avg_pool1d_backward_from_2d::<B>(x, grad)
    }
    /// One-dimensional max pooling.
    ///
    /// Default: implemented on top of the 2D version.
    fn max_pool1d(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> FloatTensor<B> {
        pool::max_pool1d_from_2d::<B>(x, kernel_size, stride, padding, dilation, ceil_mode)
    }
    /// One-dimensional max pooling that also returns the argmax indices.
    fn max_pool1d_with_indices(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
    ) -> MaxPool1dWithIndices<B> {
        pool::max_pool1d_with_indices_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
        )
    }
    /// Backward pass of [`ModuleOps::max_pool1d_with_indices`], routing the
    /// gradient to the positions recorded in `indices`.
    #[allow(clippy::too_many_arguments)]
    fn max_pool1d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: usize,
        stride: usize,
        padding: usize,
        dilation: usize,
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool1dBackward<B> {
        pool::max_pool1d_with_indices_backward_from_2d::<B>(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            output_grad,
            indices,
        )
    }
    /// Two-dimensional max pooling. Required.
    fn max_pool2d(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> FloatTensor<B>;
    /// Two-dimensional max pooling that also returns the argmax indices.
    /// Required.
    fn max_pool2d_with_indices(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
    ) -> MaxPool2dWithIndices<B>;
    /// Backward pass of [`ModuleOps::max_pool2d_with_indices`]. Required.
    #[allow(clippy::too_many_arguments)]
    fn max_pool2d_with_indices_backward(
        x: FloatTensor<B>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        dilation: [usize; 2],
        ceil_mode: bool,
        output_grad: FloatTensor<B>,
        indices: IntTensor<B>,
    ) -> MaxPool2dBackward<B>;
    /// Spatial resampling of `x` to `output_size`. Required.
    fn interpolate(
        x: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
    /// Backward pass of [`ModuleOps::interpolate`]. Required.
    fn interpolate_backward(
        x: FloatTensor<B>,
        grad: FloatTensor<B>,
        output_size: [usize; 2],
        options: InterpolateOptions,
    ) -> FloatTensor<B>;
    /// Scaled dot-product attention over `query`/`key`/`value`, with an
    /// optional boolean mask.
    ///
    /// Default: delegates to the naive reference implementation.
    fn attention(
        query: FloatTensor<B>,
        key: FloatTensor<B>,
        value: FloatTensor<B>,
        mask: Option<BoolTensor<B>>,
    ) -> FloatTensor<B> {
        attention::naive_attention::<B>(query, key, value, mask)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Success path: previously only the panic paths of `check_nonzero`
    // were covered.
    #[test]
    fn check_nonzero_returns_value_when_nonzero() {
        assert_eq!(check_nonzero(1, "must be non-zero"), 1);
        assert_eq!(check_nonzero(7, "must be non-zero"), 7);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_options_stride_zero() {
        let _opt = ConvOptions::new([0, 1], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_options_dilation_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_options_groups_zero() {
        let _opt = ConvOptions::new([1, 1], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn conv_transpose_options_stride_zero() {
        let _opt = ConvTransposeOptions::new([0, 1], [0, 0], [0, 0], [1, 1], 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn conv_transpose_options_dilation_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [0, 0], 1);
    }

    #[test]
    #[should_panic = "groups must be non-zero"]
    fn conv_transpose_options_groups_zero() {
        let _opt = ConvTransposeOptions::new([1, 1], [0, 0], [0, 0], [1, 1], 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn deform_conv_options_stride_zero() {
        let _opt = DeformConvOptions::new([0, 1], [0, 0], [1, 1], 1, 1);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn deform_conv_options_dilation_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [0, 0], 1, 1);
    }

    #[test]
    #[should_panic = "weight groups must be non-zero"]
    fn deform_conv_options_weights_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 0, 1);
    }

    #[test]
    #[should_panic = "offset groups must be non-zero"]
    fn deform_conv_options_offset_groups_zero() {
        let _opt = DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 0);
    }

    #[test]
    #[should_panic = "stride must be non-zero"]
    fn unfold_options_stride_zero() {
        let _opt = UnfoldOptions::new([0, 1], [0, 0], [1, 1]);
    }

    #[test]
    #[should_panic = "dilation must be non-zero"]
    fn unfold_options_dilation_zero() {
        let _opt = UnfoldOptions::new([1, 1], [0, 0], [0, 0]);
    }
}