/// Trait that allows a backend to support autodiff.
///
/// An autodiff backend is paired with an [`ADBackend::InnerBackend`] whose `Device` and
/// `FloatElem` types are the same as `Self`'s. Tensor primitives can be converted to and
/// from that inner backend with [`ADBackend::inner`] and [`ADBackend::from_inner`].
pub trait ADBackend: Backend {
    /// Backend paired with this autodiff backend.
    ///
    /// The bound requires its `Device` and `FloatElem` to be exactly those of `Self`.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem>;

    /// Gradient container returned by [`ADBackend::backward`] and queried by
    /// [`ADBackend::grad`] / [`ADBackend::grad_remove`].
    ///
    /// Must be `Send + Sync`.
    type Gradients: Send + Sync;

    /// Consumes `tensor` and returns the resulting gradients.
    ///
    /// NOTE(review): presumably `tensor` is the root of the backward pass (e.g. a loss);
    /// confirm against implementors.
    fn backward<const D: usize>(tensor: Self::TensorPrimitive<D>) -> Self::Gradients;

    /// Looks up the gradient for `tensor` in `grads` and returns it as an
    /// inner-backend primitive.
    ///
    /// `grads` is only borrowed immutably. A `None` result presumably means no gradient
    /// is stored for `tensor` (confirm against implementors).
    fn grad<const D: usize>(
        tensor: &Self::TensorPrimitive<D>,
        grads: &Self::Gradients,
    ) -> Option<<Self::InnerBackend as Backend>::TensorPrimitive<D>>;

    /// Like [`ADBackend::grad`], but takes `grads` mutably.
    ///
    /// NOTE(review): the name suggests the gradient for `tensor` is removed from
    /// `grads` when returned; confirm against implementors.
    fn grad_remove<const D: usize>(
        tensor: &Self::TensorPrimitive<D>,
        grads: &mut Self::Gradients,
    ) -> Option<<Self::InnerBackend as Backend>::TensorPrimitive<D>>;

    /// Consumes an autodiff tensor primitive and returns the corresponding
    /// [`ADBackend::InnerBackend`] primitive.
    fn inner<const D: usize>(
        tensor: Self::TensorPrimitive<D>,
    ) -> <Self::InnerBackend as Backend>::TensorPrimitive<D>;

    /// Converts an [`ADBackend::InnerBackend`] primitive into this backend's tensor
    /// primitive. This is the reverse direction of [`ADBackend::inner`].
    fn from_inner<const D: usize>(
        tensor: <Self::InnerBackend as Backend>::TensorPrimitive<D>,
    ) -> Self::TensorPrimitive<D>;
}
Description

Trait that allows a backend to support automatic differentiation (autodiff).

Required Associated Types

source

type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem>

source

type Gradients: Send + Sync

Required Methods

source

fn backward<const D: usize>(tensor: Self::TensorPrimitive<D>) -> Self::Gradients

source

fn grad<const D: usize>(tensor: &Self::TensorPrimitive<D>, grads: &Self::Gradients) -> Option<<<Self as ADBackend>::InnerBackend as Backend>::TensorPrimitive<D>>

source

fn grad_remove<const D: usize>(tensor: &Self::TensorPrimitive<D>, grads: &mut Self::Gradients) -> Option<<<Self as ADBackend>::InnerBackend as Backend>::TensorPrimitive<D>>

source

fn inner<const D: usize>(tensor: Self::TensorPrimitive<D>) -> <Self::InnerBackend as Backend>::TensorPrimitive<D>

source

fn from_inner<const D: usize>(tensor: <Self::InnerBackend as Backend>::TensorPrimitive<D>) -> Self::TensorPrimitive<D>

Implementors