use super::conv;
use crate::backend::Backend;
/// Gradients produced by the backward pass of a 2D convolution.
///
/// Returned by [`ModuleOps::conv2d_backward`]. The `<4>` / `<1>` const
/// parameters on the tensor primitives are presumably the tensor ranks.
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
/// Gradient with respect to the convolution input `x`.
pub x_grad: B::TensorPrimitive<4>,
/// Gradient with respect to the convolution `weight`.
pub weights_grad: B::TensorPrimitive<4>,
/// Gradient with respect to the bias. Presumably `None` when the forward
/// convolution had no bias — confirm against `conv::conv2d_backward`.
pub bias_grad: Option<B::TensorPrimitive<1>>,
}
/// Result of the backward pass of 2D max pooling.
///
/// Returned by [`ModuleOps::max_pool2d_with_indexes_backward`].
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
/// Gradient with respect to the pooling input `x`.
pub x_grad: B::TensorPrimitive<4>,
}
/// Output of 2D max pooling together with the indexes needed for its backward pass.
///
/// Returned by [`ModuleOps::max_pool2d_with_indexes`]; `indexes` is later fed
/// to [`ModuleOps::max_pool2d_with_indexes_backward`].
#[derive(new)]
pub struct MaxPool2dWithIndexes<B: Backend> {
/// The pooled tensor.
pub output: B::TensorPrimitive<4>,
/// Integer indexes recorded during pooling — presumably the position of each
/// selected maximum within the input; confirm against backend implementations.
pub indexes: B::IntTensorPrimitive<4>,
}
/// Gradients produced by the backward pass of a 1D convolution.
///
/// Returned by [`ModuleOps::conv1d_backward`].
#[derive(new)]
pub struct Conv1dBackward<B: Backend> {
/// Gradient with respect to the convolution input `x`.
pub x_grad: B::TensorPrimitive<3>,
/// Gradient with respect to the convolution `weight`.
pub weights_grad: B::TensorPrimitive<3>,
/// Gradient with respect to the bias. Presumably `None` when the forward
/// convolution had no bias — confirm against `conv::conv1d_backward`.
pub bias_grad: Option<B::TensorPrimitive<1>>,
}
/// Options for an `N`-dimensional convolution.
///
/// Every array holds one entry per spatial dimension: `ConvOptions<1>` is used
/// by `conv1d`, `ConvOptions<2>` by `conv2d`.
// `PartialEq`/`Eq`/`Hash` are derived so options can be compared and used as
// map keys; `[usize; N]` implements all three for every `N`.
#[derive(new, Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConvOptions<const N: usize> {
    /// Stride of the convolution, per spatial dimension.
    pub stride: [usize; N],
    /// Padding applied to the input, per spatial dimension.
    pub padding: [usize; N],
    /// Spacing between kernel elements, per spatial dimension.
    pub dilation: [usize; N],
    /// Number of groups the channels are split into.
    pub groups: usize,
}
/// Options for an `N`-dimensional transposed convolution.
///
/// Every array holds one entry per spatial dimension:
/// `ConvTransposeOptions<1>` is used by `conv_transpose1d`,
/// `ConvTransposeOptions<2>` by `conv_transpose2d`.
// Same trait set as `ConvOptions` so both option types can be compared and
// used as map keys in the same way.
#[derive(new, Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConvTransposeOptions<const N: usize> {
    /// Stride of the transposed convolution, per spatial dimension.
    pub stride: [usize; N],
    /// Padding, per spatial dimension.
    pub padding: [usize; N],
    /// Extra output padding, per spatial dimension — presumably the
    /// conventional transposed-convolution `output_padding` used to
    /// disambiguate the output size; confirm against backend implementations.
    pub padding_out: [usize; N],
    /// Spacing between kernel elements, per spatial dimension.
    pub dilation: [usize; N],
    /// Number of groups the channels are split into.
    pub groups: usize,
}
/// Operations backing neural-network modules (embedding, convolution and
/// pooling) that a [`Backend`] provides.
///
/// Methods without a body must be implemented by every backend. The 1D
/// convolutions and the convolution backward passes have default
/// implementations that delegate to helpers in the sibling `conv` module; a
/// backend may override them with a native implementation.
pub trait ModuleOps<B: Backend> {
/// Embedding lookup: presumably gathers a row of `weights` for every entry of
/// `indexes` (`<2>` weights and indexes produce a `<3>` output) — confirm
/// against backend implementations.
fn embedding(
weights: B::TensorPrimitive<2>,
indexes: B::IntTensorPrimitive<2>,
) -> B::TensorPrimitive<3>;
/// Backward pass of [`Self::embedding`].
///
/// NOTE(review): `output` has the embedding output's `<3>` type and is
/// presumably the gradient flowing back from that output (the conv backward
/// methods call this `output_grad`). The returned `<2>` tensor matches the
/// `weights` type and is presumably the weight gradient — confirm.
fn embedding_backward(
weights: B::TensorPrimitive<2>,
output: B::TensorPrimitive<3>,
indexes: B::IntTensorPrimitive<2>,
) -> B::TensorPrimitive<2>;
/// 2D convolution of `x` with `weight`, an optional `bias`, configured by `options`.
fn conv2d(
x: B::TensorPrimitive<4>,
weight: B::TensorPrimitive<4>,
bias: Option<B::TensorPrimitive<1>>,
options: ConvOptions<2>,
) -> B::TensorPrimitive<4>;
/// 2D transposed convolution of `x` with `weight`, an optional `bias`,
/// configured by `options`. No default implementation.
fn conv_transpose2d(
x: B::TensorPrimitive<4>,
weight: B::TensorPrimitive<4>,
bias: Option<B::TensorPrimitive<1>>,
options: ConvTransposeOptions<2>,
) -> B::TensorPrimitive<4>;
/// Gradients of [`Self::conv2d`] given the gradient of its output (`output_grad`).
///
/// Default implementation delegates to `conv::conv2d_backward`.
fn conv2d_backward(
x: B::TensorPrimitive<4>,
weight: B::TensorPrimitive<4>,
bias: Option<B::TensorPrimitive<1>>,
output_grad: B::TensorPrimitive<4>,
options: ConvOptions<2>,
) -> Conv2dBackward<B> {
conv::conv2d_backward(x, weight, bias, output_grad, options)
}
/// 1D convolution of `x` with `weight`, an optional `bias`, configured by `options`.
///
/// Default implementation delegates to `conv::conv1d_from_conv2d`, which (by
/// its name) expresses the operation through a 2D convolution.
fn conv1d(
x: B::TensorPrimitive<3>,
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
options: ConvOptions<1>,
) -> B::TensorPrimitive<3> {
// Explicit `::<B>` because `B` cannot be inferred from the associated
// tensor types alone.
conv::conv1d_from_conv2d::<B>(x, weight, bias, options)
}
/// 1D transposed convolution of `x` with `weight`, an optional `bias`,
/// configured by `options`.
///
/// Default implementation delegates to
/// `conv::conv_transpose1d_from_conv_transpose2d`.
fn conv_transpose1d(
x: B::TensorPrimitive<3>,
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
options: ConvTransposeOptions<1>,
) -> B::TensorPrimitive<3> {
conv::conv_transpose1d_from_conv_transpose2d::<B>(x, weight, bias, options)
}
/// Gradients of [`Self::conv1d`] given the gradient of its output (`output_grad`).
///
/// Default implementation delegates to `conv::conv1d_backward`.
fn conv1d_backward(
x: B::TensorPrimitive<3>,
weight: B::TensorPrimitive<3>,
bias: Option<B::TensorPrimitive<1>>,
output_grad: B::TensorPrimitive<3>,
options: ConvOptions<1>,
) -> Conv1dBackward<B> {
conv::conv1d_backward(x, weight, bias, output_grad, options)
}
/// 2D average pooling of `x`; `kernel_size`, `stride` and `padding` each
/// hold one value per spatial dimension.
fn avg_pool2d(
x: B::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> B::TensorPrimitive<4>;
/// Backward pass of [`Self::avg_pool2d`].
///
/// `grad` is presumably the gradient with respect to the pooled output and
/// the result the gradient with respect to `x` — confirm against backends.
fn avg_pool2d_backward(
x: B::TensorPrimitive<4>,
grad: B::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> B::TensorPrimitive<4>;
/// 2D max pooling of `x`; `kernel_size`, `stride` and `padding` each hold one
/// value per spatial dimension.
fn max_pool2d(
x: B::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> B::TensorPrimitive<4>;
/// Like [`Self::max_pool2d`], but also returns the indexes consumed by
/// [`Self::max_pool2d_with_indexes_backward`].
fn max_pool2d_with_indexes(
x: B::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
) -> MaxPool2dWithIndexes<B>;
/// Backward pass of 2D max pooling.
///
/// Uses the `indexes` returned by [`Self::max_pool2d_with_indexes`] and the
/// gradient of the pooled output (`output_grad`) to produce the gradient with
/// respect to `x`. The pooling parameters should match the forward call.
fn max_pool2d_with_indexes_backward(
x: B::TensorPrimitive<4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
output_grad: B::TensorPrimitive<4>,
indexes: B::IntTensorPrimitive<4>,
) -> MaxPool2dBackward<B>;
}