Trait burn_core::tensor::backend::AutodiffBackend
pub trait AutodiffBackend: Backend {
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem, FullPrecisionElem = Self::FullPrecisionElem>;
    type Gradients: Send + Sync;

    // Required methods
    fn backward<const D: usize>(
        tensor: Self::TensorPrimitive<D>
    ) -> Self::Gradients;
    fn grad<const D: usize>(
        tensor: &Self::TensorPrimitive<D>,
        grads: &Self::Gradients
    ) -> Option<<Self::InnerBackend as Backend>::TensorPrimitive<D>>;
    fn grad_remove<const D: usize>(
        tensor: &Self::TensorPrimitive<D>,
        grads: &mut Self::Gradients
    ) -> Option<<Self::InnerBackend as Backend>::TensorPrimitive<D>>;
    fn grad_replace<const D: usize>(
        tensor: &Self::TensorPrimitive<D>,
        grads: &mut Self::Gradients,
        grad: <Self::InnerBackend as Backend>::TensorPrimitive<D>
    );
    fn inner<const D: usize>(
        tensor: Self::TensorPrimitive<D>
    ) -> <Self::InnerBackend as Backend>::TensorPrimitive<D>;
    fn int_inner<const D: usize>(
        tensor: Self::IntTensorPrimitive<D>
    ) -> <Self::InnerBackend as Backend>::IntTensorPrimitive<D>;
    fn bool_inner<const D: usize>(
        tensor: Self::BoolTensorPrimitive<D>
    ) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive<D>;
    fn from_inner<const D: usize>(
        tensor: <Self::InnerBackend as Backend>::TensorPrimitive<D>
    ) -> Self::TensorPrimitive<D>;
    fn int_from_inner<const D: usize>(
        tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive<D>
    ) -> Self::IntTensorPrimitive<D>;
    fn bool_from_inner<const D: usize>(
        tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive<D>
    ) -> Self::BoolTensorPrimitive<D>;
}
Trait that allows a backend to support autodiff.
Required Associated Types
type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem, FullPrecisionElem = Self::FullPrecisionElem>
The inner backend type.
Required Methods
fn backward<const D: usize>(tensor: Self::TensorPrimitive<D>) -> Self::Gradients
fn grad<const D: usize>(
    tensor: &Self::TensorPrimitive<D>,
    grads: &Self::Gradients
) -> Option<<Self::InnerBackend as Backend>::TensorPrimitive<D>>
fn grad_remove<const D: usize>(
    tensor: &Self::TensorPrimitive<D>,
    grads: &mut Self::Gradients
) -> Option<<Self::InnerBackend as Backend>::TensorPrimitive<D>>
fn grad_replace<const D: usize>(
    tensor: &Self::TensorPrimitive<D>,
    grads: &mut Self::Gradients,
    grad: <Self::InnerBackend as Backend>::TensorPrimitive<D>
)
Replace the gradient of a tensor with the one provided.
If no gradient existed for the tensor, register the provided one.
Arguments

- `tensor` - The tensor whose gradient is replaced (or registered, if it had none).
- `grads` - The gradients.
- `grad` - The updated grad tensor.
fn inner<const D: usize>(
    tensor: Self::TensorPrimitive<D>
) -> <Self::InnerBackend as Backend>::TensorPrimitive<D>
fn int_inner<const D: usize>(
    tensor: Self::IntTensorPrimitive<D>
) -> <Self::InnerBackend as Backend>::IntTensorPrimitive<D>
fn bool_inner<const D: usize>(
    tensor: Self::BoolTensorPrimitive<D>
) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive<D>
fn from_inner<const D: usize>(
    tensor: <Self::InnerBackend as Backend>::TensorPrimitive<D>
) -> Self::TensorPrimitive<D>
fn int_from_inner<const D: usize>(
    tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive<D>
) -> Self::IntTensorPrimitive<D>
fn bool_from_inner<const D: usize>(
    tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive<D>
) -> Self::BoolTensorPrimitive<D>
Object Safety
This trait is not object safe.