Trait coaster_nn::Convolution

pub trait Convolution<F>: NN<F> {
    fn new_convolution_config(
        &self,
        src: &SharedTensor<F>,
        dest: &SharedTensor<F>,
        filter: &SharedTensor<F>,
        algo_fwd: ConvForwardAlgo,
        algo_bwd_filter: ConvBackwardFilterAlgo,
        algo_bwd_data: ConvBackwardDataAlgo,
        stride: &[i32],
        zero_padding: &[i32]
    ) -> Result<Self::CC, Error>; fn convolution(
        &self,
        filter: &SharedTensor<F>,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC
    ) -> Result<(), Error>; fn convolution_grad_filter(
        &self,
        src_data: &SharedTensor<F>,
        dest_diff: &SharedTensor<F>,
        filter_diff: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC
    ) -> Result<(), Error>; fn convolution_grad_data(
        &self,
        filter: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC
    ) -> Result<(), Error>; }

Provides the functionality for a backend to support convolution operations.

Required Methods

`new_convolution_config`: Creates a new convolution configuration (`Self::CC`), which must be passed to the subsequent convolution operations.

`convolution`: Computes a CNN convolution over the input tensor `x`.

Saves the result to `result`.

`convolution_grad_filter`: Computes the gradient of a CNN convolution with respect to the filter.

Saves the result to `filter_diff`.

`convolution_grad_data`: Computes the gradient of a CNN convolution with respect to the input data.

Saves the result to `result_diff`.

Implementors