Trait burn_tensor::backend::AutodiffBackend

source ·
/// Trait that allows a backend to support autodiff.
///
/// An autodiff backend is layered over an inner backend (see
/// [`AutodiffBackend::InnerBackend`]). Tensors can be converted in both directions with the
/// `*inner` / `*from_inner` methods.
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    ///
    /// The bound requires it to use the same device, float element and int element types as
    /// the autodiff backend itself.
    type InnerBackend: Backend<
        Device = Self::Device,
        FloatElem = Self::FloatElem,
        IntElem = Self::IntElem,
    >;

    /// Gradients type; must be `Send`.
    type Gradients: Send;

    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The last node of the computational graph, from which the gradients are
    ///   computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward<const D: usize>(tensor: FloatTensor<Self, D>) -> Self::Gradients;

    /// Returns the gradient of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradient from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor, on the inner backend, containing the gradient.
    fn grad<const D: usize>(
        tensor: &FloatTensor<Self, D>,
        grads: &Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend, D>>;

    /// Pops the gradient of a tensor out of `grads` and returns it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradient from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the removed gradient.
    fn grad_remove<const D: usize>(
        tensor: &FloatTensor<Self, D>,
        grads: &mut Self::Gradients,
    ) -> Option<FloatTensor<Self::InnerBackend, D>>;

    /// Replaces the gradient of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, it is registered.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradient is replaced (or registered).
    /// * `grads` - The gradients.
    /// * `grad` - The updated gradient tensor.
    fn grad_replace<const D: usize>(
        tensor: &FloatTensor<Self, D>,
        grads: &mut Self::Gradients,
        grad: FloatTensor<Self::InnerBackend, D>,
    );

    /// Returns the float tensor with the inner backend type.
    fn inner<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self::InnerBackend, D>;

    /// Returns the int tensor with the inner backend type.
    fn int_inner<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self::InnerBackend, D>;

    /// Returns the bool tensor with the inner backend type.
    fn bool_inner<const D: usize>(
        tensor: BoolTensor<Self, D>,
    ) -> BoolTensor<Self::InnerBackend, D>;

    /// Returns the quantized tensor with the inner backend type.
    fn q_inner<const D: usize>(
        tensor: QuantizedTensor<Self, D>,
    ) -> QuantizedTensor<Self::InnerBackend, D>;

    /// Converts an inner backend float tensor to the autodiff backend float tensor.
    fn from_inner<const D: usize>(
        tensor: FloatTensor<Self::InnerBackend, D>,
    ) -> FloatTensor<Self, D>;

    /// Converts an inner backend int tensor to the autodiff backend int tensor.
    fn int_from_inner<const D: usize>(
        tensor: IntTensor<Self::InnerBackend, D>,
    ) -> IntTensor<Self, D>;

    /// Converts an inner backend bool tensor to the autodiff backend bool tensor.
    fn bool_from_inner<const D: usize>(
        tensor: BoolTensor<Self::InnerBackend, D>,
    ) -> BoolTensor<Self, D>;

    /// Converts an inner backend quantized tensor to the autodiff backend quantized tensor.
    fn q_from_inner<const D: usize>(
        tensor: QuantizedTensor<Self::InnerBackend, D>,
    ) -> QuantizedTensor<Self, D>;
}
Expand description

Trait that allows a backend to support autodiff.

Required Associated Types§

source

type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>

The inner backend type.

source

type Gradients: Send

Gradients type.

Required Methods§

source

fn backward<const D: usize>(tensor: FloatTensor<Self, D>) -> Self::Gradients

Backward pass.

§Arguments
  • tensor - The last node of the computational graph, from which the gradients are computed.
§Returns

The gradients.

source

fn grad<const D: usize>( tensor: &FloatTensor<Self, D>, grads: &Self::Gradients, ) -> Option<FloatTensor<Self::InnerBackend, D>>

Returns the gradients of a tensor.

§Arguments
  • tensor - The tensor to extract the gradients from.
  • grads - The gradients.
§Returns

An optional tensor containing the gradient.

source

fn grad_remove<const D: usize>( tensor: &FloatTensor<Self, D>, grads: &mut Self::Gradients, ) -> Option<FloatTensor<Self::InnerBackend, D>>

Pops the gradients of a tensor and returns them.

§Arguments
  • tensor - The tensor to pop the gradients from.
  • grads - The gradients.
§Returns

An optional tensor containing the removed gradient.

source

fn grad_replace<const D: usize>( tensor: &FloatTensor<Self, D>, grads: &mut Self::Gradients, grad: FloatTensor<Self::InnerBackend, D>, )

Replace the gradients of a tensor with the one provided.

If no gradient existed for the provided tensor, register it.

§Arguments
  • tensor - The tensor whose gradient is replaced (or registered, if none existed).
  • grads - The gradients.
  • grad - The updated grad tensor.
source

fn inner<const D: usize>( tensor: FloatTensor<Self, D>, ) -> FloatTensor<Self::InnerBackend, D>

Returns the tensor with inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

source

fn int_inner<const D: usize>( tensor: IntTensor<Self, D>, ) -> IntTensor<Self::InnerBackend, D>

Returns the tensor with inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

source

fn bool_inner<const D: usize>( tensor: BoolTensor<Self, D>, ) -> BoolTensor<Self::InnerBackend, D>

Returns the tensor with inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

source

fn q_inner<const D: usize>( tensor: QuantizedTensor<Self, D>, ) -> QuantizedTensor<Self::InnerBackend, D>

Returns the tensor with inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

source

fn from_inner<const D: usize>( tensor: FloatTensor<Self::InnerBackend, D>, ) -> FloatTensor<Self, D>

Converts the inner backend tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

source

fn int_from_inner<const D: usize>( tensor: IntTensor<Self::InnerBackend, D>, ) -> IntTensor<Self, D>

Converts the inner backend tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

source

fn bool_from_inner<const D: usize>( tensor: BoolTensor<Self::InnerBackend, D>, ) -> BoolTensor<Self, D>

Converts the inner backend tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

source

fn q_from_inner<const D: usize>( tensor: QuantizedTensor<Self::InnerBackend, D>, ) -> QuantizedTensor<Self, D>

Converts the inner backend tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

Object Safety§

This trait is not object safe.

Implementors§