/// Backend operations for embedding lookup, 1D/2D convolution and 2D max pooling,
/// together with their backward passes, parameterized over the backend `B`.
///
/// The const generic on `TensorPrimitive<D>` / `IntTensorPrimitive<D>` is the tensor
/// rank: e.g. `conv2d` takes a rank-4 `x` documented as
/// `[batch_size, channels_in, height, width]`.
///
/// NOTE(review): this listing appears to be a rustdoc rendering — the provided
/// methods' bodies are elided as `{ ... }`, which is not valid Rust as written.
pub trait ModuleOps<B: Backend> {
    // Required methods
    /// Embedding lookup: takes rank-2 float `weights` and rank-2 integer `indexes`,
    /// and returns a rank-3 float tensor.
    ///
    /// NOTE(review): presumably gathers one row of `weights` per entry of `indexes`
    /// (so the output is `indexes`' shape plus the embedding dimension) — confirm
    /// against implementors.
    fn embedding(
        weights: B::TensorPrimitive<2>,
        indexes: B::IntTensorPrimitive<2>
    ) -> B::TensorPrimitive<3>;
    /// Backward pass for [`ModuleOps::embedding`]; returns a rank-2 tensor, the same
    /// rank as `weights`.
    ///
    /// NOTE(review): `output` is presumably the gradient w.r.t. the embedding output,
    /// and the result the gradient w.r.t. `weights` — confirm.
    fn embedding_backward(
        weights: B::TensorPrimitive<2>,
        output: B::TensorPrimitive<3>,
        indexes: B::IntTensorPrimitive<2>
    ) -> B::TensorPrimitive<2>;
    /// Two-dimensional convolution.
    ///
    /// # Shapes
    ///
    /// - `x`: `[batch_size, channels_in, height, width]`
    /// - `weight`: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`
    /// - `bias` (optional): `[channels_out]`
    ///
    /// Returns a rank-4 tensor. `stride` and `padding` hold one value per spatial
    /// dimension (presumably in `[height, width]` order — confirm).
    fn conv2d(
        x: B::TensorPrimitive<4>,
        weight: B::TensorPrimitive<4>,
        bias: Option<B::TensorPrimitive<1>>,
        stride: [usize; 2],
        padding: [usize; 2]
    ) -> B::TensorPrimitive<4>;
    /// Two-dimensional max pooling.
    ///
    /// # Shapes
    ///
    /// - `x`: `[batch_size, channels, height, width]`
    ///
    /// Returns a rank-4 tensor. `kernel_size`, `stride` and `padding` hold one value
    /// per spatial dimension (presumably `[height, width]` order — confirm).
    fn max_pool2d(
        x: B::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2]
    ) -> B::TensorPrimitive<4>;
    /// Two-dimensional max pooling that also reports the indexes of the selected
    /// maxima.
    ///
    /// # Shapes
    ///
    /// - `x`: `[batch_size, channels, height, width]`
    ///
    /// NOTE(review): the returned [`MaxPool2dWithIndexes`] presumably carries the
    /// pooled output plus the rank-4 integer `indexes` consumed by
    /// [`ModuleOps::max_pool2d_with_indexes_backward`] — confirm.
    fn max_pool2d_with_indexes(
        x: B::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2]
    ) -> MaxPool2dWithIndexes<B>;
    /// Backward pass for the 2D max pooling operation.
    ///
    /// Takes the forward input `x`, the same `kernel_size` / `stride` / `padding`,
    /// the upstream gradient `output_grad`, and the rank-4 integer `indexes`.
    ///
    /// NOTE(review): `indexes` is presumably the value produced by
    /// [`ModuleOps::max_pool2d_with_indexes`] for the same input — confirm.
    fn max_pool2d_with_indexes_backward(
        x: B::TensorPrimitive<4>,
        kernel_size: [usize; 2],
        stride: [usize; 2],
        padding: [usize; 2],
        output_grad: B::TensorPrimitive<4>,
        indexes: B::IntTensorPrimitive<4>
    ) -> MaxPool2dBackward<B>;

    // Provided methods (default bodies exist but are elided as `{ ... }` here)
    /// Backward pass for the conv2d operation.
    ///
    /// NOTE(review): unlike [`ModuleOps::conv2d`], this takes no `padding` argument —
    /// confirm the default implementation yields correct gradients when the forward
    /// pass used non-zero padding.
    fn conv2d_backward(
        x: B::TensorPrimitive<4>,
        weight: B::TensorPrimitive<4>,
        bias: Option<B::TensorPrimitive<1>>,
        stride: [usize; 2],
        output_grad: B::TensorPrimitive<4>
    ) -> Conv2dBackward<B> { ... }
    /// One-dimensional convolution.
    ///
    /// # Shapes
    ///
    /// - `x`: `[batch_size, channels_in, length]`
    /// - `weight`: `[channels_out, channels_in, kernel_size]`
    /// - `bias` (optional): `[channels_out]`
    ///
    /// Returns a rank-3 tensor; `stride` and `padding` are single scalars.
    fn conv1d(
        x: B::TensorPrimitive<3>,
        weight: B::TensorPrimitive<3>,
        bias: Option<B::TensorPrimitive<1>>,
        stride: usize,
        padding: usize
    ) -> B::TensorPrimitive<3> { ... }
    /// Backward pass for the conv1d operation.
    ///
    /// NOTE(review): unlike [`ModuleOps::conv1d`], this takes no `padding` argument —
    /// confirm the default implementation handles padded forward passes.
    fn conv1d_backward(
        x: B::TensorPrimitive<3>,
        weight: B::TensorPrimitive<3>,
        bias: Option<B::TensorPrimitive<1>>,
        stride: usize,
        output_grad: B::TensorPrimitive<3>
    ) -> Conv1dBackward<B> { ... }
}

Required Methods

source

fn embedding( weights: B::TensorPrimitive<2>, indexes: B::IntTensorPrimitive<2> ) -> B::TensorPrimitive<3>

source

fn embedding_backward( weights: B::TensorPrimitive<2>, output: B::TensorPrimitive<3>, indexes: B::IntTensorPrimitive<2> ) -> B::TensorPrimitive<2>

source

fn conv2d( x: B::TensorPrimitive<4>, weight: B::TensorPrimitive<4>, bias: Option<B::TensorPrimitive<1>>, stride: [usize; 2], padding: [usize; 2] ) -> B::TensorPrimitive<4>

Two-dimensional convolution.

Shapes

x: [batch_size, channels_in, height, width]; weight: [channels_out, channels_in, kernel_size_1, kernel_size_2]; bias: [channels_out].

source

fn max_pool2d( x: B::TensorPrimitive<4>, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2] ) -> B::TensorPrimitive<4>

Two-dimensional max pooling.

Shapes

x: [batch_size, channels, height, width].

source

fn max_pool2d_with_indexes( x: B::TensorPrimitive<4>, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2] ) -> MaxPool2dWithIndexes<B>

Two-dimensional max pooling that also returns the indexes of the selected maxima.

Shapes

x: [batch_size, channels, height, width].

source

fn max_pool2d_with_indexes_backward( x: B::TensorPrimitive<4>, kernel_size: [usize; 2], stride: [usize; 2], padding: [usize; 2], output_grad: B::TensorPrimitive<4>, indexes: B::IntTensorPrimitive<4> ) -> MaxPool2dBackward<B>

Backward pass for the 2D max pooling operation.

Provided Methods

source

fn conv2d_backward( x: B::TensorPrimitive<4>, weight: B::TensorPrimitive<4>, bias: Option<B::TensorPrimitive<1>>, stride: [usize; 2], output_grad: B::TensorPrimitive<4> ) -> Conv2dBackward<B>

Backward pass for the conv2d operation.

source

fn conv1d( x: B::TensorPrimitive<3>, weight: B::TensorPrimitive<3>, bias: Option<B::TensorPrimitive<1>>, stride: usize, padding: usize ) -> B::TensorPrimitive<3>

One-dimensional convolution.

Shapes

x: [batch_size, channels_in, length]; weight: [channels_out, channels_in, kernel_size]; bias: [channels_out].

source

fn conv1d_backward( x: B::TensorPrimitive<3>, weight: B::TensorPrimitive<3>, bias: Option<B::TensorPrimitive<1>>, stride: usize, output_grad: B::TensorPrimitive<3> ) -> Conv1dBackward<B>

Backward pass for the conv1d operation.

Implementors