Trait auto_diff::tensor::convolution::Convolution

/// Convolution operations (forward and gradient) on a tensor type.
///
/// Every method is called on `self`, receives the kernel as `filter`, and
/// produces values of the associated [`Convolution::TensorType`].
///
/// Two families of methods are declared:
/// * `conv2d` / `conv2d_grad` take `stride`, `padding` and `dilation` as
///   `(usize, usize)` pairs;
/// * `conv_gen` / `conv_grad_gen` take the same parameters as `&[usize]`
///   slices (presumably one entry per spatial dimension — TODO confirm).
pub trait Convolution {
    /// Tensor type used for `filter`, `output_grad`, and every returned value.
    type TensorType;

    /// Convolves `self` with `filter` using pair-valued parameters.
    ///
    /// `padding_mode` is forwarded as given; its exact effect is defined by
    /// the implementor (NOTE(review): document `PaddingMode` variants there).
    fn conv2d(
        &self,
        filter: &Self::TensorType,
        stride: (usize, usize),
        padding: (usize, usize),
        dilation: (usize, usize),
        padding_mode: PaddingMode,
    ) -> Self::TensorType;

    /// Gradient counterpart of [`Convolution::conv2d`].
    ///
    /// Takes the same parameters as `conv2d` plus `output_grad`, and returns a
    /// pair of tensors — presumably (gradient w.r.t. filter, gradient w.r.t.
    /// `self`) or the reverse; TODO confirm the order against the implementor.
    fn conv2d_grad(
        &self,
        filter: &Self::TensorType,
        stride: (usize, usize),
        padding: (usize, usize),
        dilation: (usize, usize),
        padding_mode: PaddingMode,
        output_grad: &Self::TensorType,
    ) -> (Self::TensorType, Self::TensorType);

    /// Convolves `self` with `filter` using slice-valued parameters.
    ///
    /// NOTE(review): the required length of `stride`, `padding` and
    /// `dilation` is not stated here — confirm against the implementor.
    fn conv_gen(
        &self,
        filter: &Self::TensorType,
        stride: &[usize],
        padding: &[usize],
        dilation: &[usize],
        padding_mode: PaddingMode,
    ) -> Self::TensorType;

    /// Gradient counterpart of [`Convolution::conv_gen`].
    ///
    /// Takes the same parameters as `conv_gen` plus `output_grad` and returns
    /// a pair of tensors (order to be confirmed, as for `conv2d_grad`).
    fn conv_grad_gen(
        &self,
        filter: &Self::TensorType,
        stride: &[usize],
        padding: &[usize],
        dilation: &[usize],
        padding_mode: PaddingMode,
        output_grad: &Self::TensorType,
    ) -> (Self::TensorType, Self::TensorType);
}

Associated Types

type TensorType — the tensor type accepted as `filter` and `output_grad`, and returned by every method of this trait.

Required methods

fn conv2d(
    &self,
    filter: &Self::TensorType,
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
    padding_mode: PaddingMode
) -> Self::TensorType

fn conv2d_grad(
    &self,
    filter: &Self::TensorType,
    stride: (usize, usize),
    padding: (usize, usize),
    dilation: (usize, usize),
    padding_mode: PaddingMode,
    output_grad: &Self::TensorType
) -> (Self::TensorType, Self::TensorType)

fn conv_gen(
    &self,
    filter: &Self::TensorType,
    stride: &[usize],
    padding: &[usize],
    dilation: &[usize],
    padding_mode: PaddingMode
) -> Self::TensorType

fn conv_grad_gen(
    &self,
    filter: &Self::TensorType,
    stride: &[usize],
    padding: &[usize],
    dilation: &[usize],
    padding_mode: PaddingMode,
    output_grad: &Self::TensorType
) -> (Self::TensorType, Self::TensorType)

This trait declares no provided (default) methods; implementors must define all four methods above.

Implementors

impl<T> Convolution for GenTensor<T> where
    T: Float

type TensorType = GenTensor<T>

Loading content...