Convolution

Trait Convolution 

Source
/// Provides the functionality for a `Backend` to support convolution operations.
///
/// A configuration is first created with [`Self::new_convolution_config`]. It is then
/// passed (as `&Self::CC`) to [`Self::convolution`], [`Self::convolution_grad_filter`]
/// and [`Self::convolution_grad_data`].
///
/// `Self::CC` is not declared in this trait; it comes from the `NN<F>` supertrait.
///
/// This trait is not dyn compatible (not object safe).
///
/// Implemented for `Backend<Cuda>` (with `T: Float + DataTypeInfo`) and for
/// `Backend<Native>` (with `T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy`).
pub trait Convolution<F>: NN<F> {
    // Required methods

    /// Creates a new convolution configuration (`Self::CC`), which must then be passed to
    /// the other convolution operations of this trait.
    ///
    /// * `src`, `dest`, `filter` - the input, output and filter tensors of the convolution.
    ///   Presumably only their shapes/descriptors are read here; TODO confirm.
    /// * `algo_fwd`, `algo_bwd_filter`, `algo_bwd_data` - the algorithm selections used by
    ///   [`Self::convolution`], [`Self::convolution_grad_filter`] and
    ///   [`Self::convolution_grad_data`] respectively.
    /// * `stride`, `zero_padding` - convolution stride and zero padding. Assumed to hold one
    ///   entry per spatial dimension; TODO confirm against the backend implementations.
    ///
    /// # Errors
    ///
    /// Returns an `Error` if the backend cannot build the configuration. The exact
    /// conditions are backend-specific and not visible here.
    fn new_convolution_config(
        &self,
        src: &SharedTensor<F>,
        dest: &SharedTensor<F>,
        filter: &SharedTensor<F>,
        algo_fwd: ConvForwardAlgo,
        algo_bwd_filter: ConvBackwardFilterAlgo,
        algo_bwd_data: ConvBackwardDataAlgo,
        stride: &[i32],
        zero_padding: &[i32],
    ) -> Result<Self::CC, Error>;

    /// Computes a [CNN convolution][convolution] over the input tensor `x` using `filter`,
    /// saving the output to `result`.
    ///
    /// `workspace` (a `SharedTensor<u8>`) is presumably backend scratch memory whose
    /// required size depends on `config`; TODO confirm. The `Backend<Native>`
    /// implementation names this parameter `_workspace`, which suggests it is unused there.
    ///
    /// # Errors
    ///
    /// Returns an `Error` if the backend operation fails.
    ///
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    fn convolution(
        &self,
        filter: &SharedTensor<F>,
        x: &SharedTensor<F>,
        result: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), Error>;

    /// Computes the gradient of a [CNN convolution][convolution] with respect to the
    /// filter, saving it to `filter_diff`.
    ///
    /// `src_data` is presumably the input of the forward convolution, and `dest_diff` the
    /// gradient with respect to its output; TODO confirm.
    ///
    /// # Errors
    ///
    /// Returns an `Error` if the backend operation fails.
    ///
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    fn convolution_grad_filter(
        &self,
        src_data: &SharedTensor<F>,
        dest_diff: &SharedTensor<F>,
        filter_diff: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), Error>;

    /// Computes the gradient of a [CNN convolution][convolution] with respect to its input
    /// data, saving it to `result_diff`.
    ///
    /// NOTE(review): `result_diff` receives the data gradient. The read-only `x_diff` is
    /// therefore presumably the gradient flowing back from the forward output, despite its
    /// name suggesting otherwise; confirm before relying on it.
    ///
    /// # Errors
    ///
    /// Returns an `Error` if the backend operation fails.
    ///
    /// [convolution]: https://en.wikipedia.org/wiki/Convolutional_neural_network
    fn convolution_grad_data(
        &self,
        filter: &SharedTensor<F>,
        x_diff: &SharedTensor<F>,
        result_diff: &mut SharedTensor<F>,
        workspace: &mut SharedTensor<u8>,
        config: &Self::CC,
    ) -> Result<(), Error>;
}
Expand description

Provides the functionality for a Backend to support Convolution operations.

Required Methods§

Source

fn new_convolution_config( &self, src: &SharedTensor<F>, dest: &SharedTensor<F>, filter: &SharedTensor<F>, algo_fwd: ConvForwardAlgo, algo_bwd_filter: ConvBackwardFilterAlgo, algo_bwd_data: ConvBackwardDataAlgo, stride: &[i32], zero_padding: &[i32], ) -> Result<Self::CC, Error>

Creates a new `ConvolutionConfig` (`Self::CC`), which must then be passed to the other convolution operations.

Source

fn convolution( &self, filter: &SharedTensor<F>, x: &SharedTensor<F>, result: &mut SharedTensor<F>, workspace: &mut SharedTensor<u8>, config: &Self::CC, ) -> Result<(), Error>

Computes a [CNN convolution](https://en.wikipedia.org/wiki/Convolutional_neural_network) over the input tensor `x`.

Saves the result to `result`.

Source

fn convolution_grad_filter( &self, src_data: &SharedTensor<F>, dest_diff: &SharedTensor<F>, filter_diff: &mut SharedTensor<F>, workspace: &mut SharedTensor<u8>, config: &Self::CC, ) -> Result<(), Error>

Computes the gradient of a [CNN convolution](https://en.wikipedia.org/wiki/Convolutional_neural_network) with respect to the filter.

Saves the result to `filter_diff`.

Source

fn convolution_grad_data( &self, filter: &SharedTensor<F>, x_diff: &SharedTensor<F>, result_diff: &mut SharedTensor<F>, workspace: &mut SharedTensor<u8>, config: &Self::CC, ) -> Result<(), Error>

Computes the gradient of a [CNN convolution](https://en.wikipedia.org/wiki/Convolutional_neural_network) with respect to its input data (the tensor `x` of the forward convolution).

Saves the result to `result_diff`.

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementations on Foreign Types§

Source§

impl<T> Convolution<T> for Backend<Cuda>
where T: Float + DataTypeInfo,

Source§

fn new_convolution_config( &self, src: &SharedTensor<T>, dest: &SharedTensor<T>, filter: &SharedTensor<T>, algo_fwd: ConvForwardAlgo, algo_bwd_filter: ConvBackwardFilterAlgo, algo_bwd_data: ConvBackwardDataAlgo, stride: &[i32], zero_padding: &[i32], ) -> Result<Self::CC, Error>

Source§

fn convolution( &self, filter: &SharedTensor<T>, x: &SharedTensor<T>, result: &mut SharedTensor<T>, workspace: &mut SharedTensor<u8>, config: &Self::CC, ) -> Result<(), Error>

Source§

fn convolution_grad_filter( &self, src_data: &SharedTensor<T>, dest_diff: &SharedTensor<T>, filter_diff: &mut SharedTensor<T>, workspace: &mut SharedTensor<u8>, config: &Self::CC, ) -> Result<(), Error>

Source§

fn convolution_grad_data( &self, filter: &SharedTensor<T>, x_diff: &SharedTensor<T>, result_diff: &mut SharedTensor<T>, workspace: &mut SharedTensor<u8>, config: &Self::CC, ) -> Result<(), Error>

Source§

impl<T> Convolution<T> for Backend<Native>
where T: Add<T, Output = T> + Mul<T, Output = T> + Default + Copy,

Source§

fn new_convolution_config( &self, src: &SharedTensor<T>, dest: &SharedTensor<T>, filter: &SharedTensor<T>, algo_fwd: ConvForwardAlgo, algo_bwd_filter: ConvBackwardFilterAlgo, algo_bwd_data: ConvBackwardDataAlgo, stride: &[i32], zero_padding: &[i32], ) -> Result<Self::CC, Error>

Source§

fn convolution( &self, filter: &SharedTensor<T>, x: &SharedTensor<T>, result: &mut SharedTensor<T>, _workspace: &mut SharedTensor<u8>, config: &Self::CC, ) -> Result<(), Error>

Source§

fn convolution_grad_filter( &self, src_data: &SharedTensor<T>, dest_diff: &SharedTensor<T>, filter_diff: &mut SharedTensor<T>, workspace: &mut SharedTensor<u8>, config: &Self::CC, ) -> Result<(), Error>

Source§

fn convolution_grad_data( &self, filter: &SharedTensor<T>, x_diff: &SharedTensor<T>, result_diff: &mut SharedTensor<T>, workspace: &mut SharedTensor<u8>, config: &Self::CC, ) -> Result<(), Error>

Implementors§