Trait AutodiffBackend

Source
/// Trait that allows a backend to support autodiff.
///
/// All methods are associated functions without a `self` receiver; they operate on the
/// tensor primitives of the autodiff backend and of its [`Self::InnerBackend`].
pub trait AutodiffBackend: Backend {
    /// The inner backend type.
    ///
    /// Its `Device`, `FloatElem` and `IntElem` associated types are constrained to be the
    /// same as those of the autodiff backend.
    type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;
    /// Gradients type.
    type Gradients: Send;

    // Required methods
    /// Backward pass.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor that is the last node of the computational graph from which
    ///   the gradients are computed.
    ///
    /// # Returns
    ///
    /// The gradients.
    fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients;
    /// Returns the gradients of a tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to extract the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the gradient.
    fn grad(
        tensor: &Self::FloatTensorPrimitive,
        grads: &Self::Gradients,
    ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>;
    /// Pops the gradients of a tensor and returns them.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to pop the gradients from.
    /// * `grads` - The gradients.
    ///
    /// # Returns
    ///
    /// An optional tensor containing the popped gradients.
    fn grad_remove(
        tensor: &Self::FloatTensorPrimitive,
        grads: &mut Self::Gradients,
    ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>;
    /// Replaces the gradient of a tensor with the one provided.
    ///
    /// If no gradient existed for the provided tensor, register it.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor whose gradient is replaced.
    /// * `grads` - The gradients.
    /// * `grad` - The updated grad tensor.
    fn grad_replace(
        tensor: &Self::FloatTensorPrimitive,
        grads: &mut Self::Gradients,
        grad: <Self::InnerBackend as Backend>::FloatTensorPrimitive,
    );
    /// Returns the float tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn inner(
        tensor: Self::FloatTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::FloatTensorPrimitive;
    /// Returns the int tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn int_inner(
        tensor: Self::IntTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::IntTensorPrimitive;
    /// Returns the bool tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn bool_inner(
        tensor: Self::BoolTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive;
    /// Returns the quantized tensor with inner backend type.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor to get the inner backend tensor for.
    ///
    /// # Returns
    ///
    /// The inner backend tensor.
    fn q_inner(
        tensor: Self::QuantizedTensorPrimitive,
    ) -> <Self::InnerBackend as Backend>::QuantizedTensorPrimitive;
    /// Converts the inner backend float tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn from_inner(
        tensor: <Self::InnerBackend as Backend>::FloatTensorPrimitive,
    ) -> Self::FloatTensorPrimitive;
    /// Converts the inner backend int tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn int_from_inner(
        tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive,
    ) -> Self::IntTensorPrimitive;
    /// Converts the inner backend bool tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn bool_from_inner(
        tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive,
    ) -> Self::BoolTensorPrimitive;
    /// Converts the inner backend quantized tensor to the autodiff backend tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The autodiff backend tensor.
    fn q_from_inner(
        tensor: <Self::InnerBackend as Backend>::QuantizedTensorPrimitive,
    ) -> Self::QuantizedTensorPrimitive;
}
Expand description

Trait that allows a backend to support autodiff.

Required Associated Types§

Source

type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>

The inner backend type.

Source

type Gradients: Send

Gradients type.

Required Methods§

Source

fn backward(tensor: Self::FloatTensorPrimitive) -> Self::Gradients

Backward pass.

§Arguments
  • tensor - The tensor that is the last node of the computational graph, from which the gradients are computed.
§Returns

The gradients.

Source

fn grad( tensor: &Self::FloatTensorPrimitive, grads: &Self::Gradients, ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>

Returns the gradients of a tensor.

§Arguments
  • tensor - The tensor to extract the gradients from.
  • grads - The gradients.
§Returns

An optional tensor containing the gradient.

Source

fn grad_remove( tensor: &Self::FloatTensorPrimitive, grads: &mut Self::Gradients, ) -> Option<<Self::InnerBackend as Backend>::FloatTensorPrimitive>

Pops the gradients of a tensor and returns them.

§Arguments
  • tensor - The tensor to pop the gradients from.
  • grads - The gradients.
§Returns

An optional tensor containing the popped gradients.

Source

fn grad_replace( tensor: &Self::FloatTensorPrimitive, grads: &mut Self::Gradients, grad: <Self::InnerBackend as Backend>::FloatTensorPrimitive, )

Replaces the gradient of a tensor with the one provided.

If no gradient existed for the provided tensor, register it.

§Arguments
  • tensor - The tensor whose gradient is replaced.
  • grads - The gradients.
  • grad - The updated grad tensor.
Source

fn inner( tensor: Self::FloatTensorPrimitive, ) -> <Self::InnerBackend as Backend>::FloatTensorPrimitive

Returns the tensor with inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

Source

fn int_inner( tensor: Self::IntTensorPrimitive, ) -> <Self::InnerBackend as Backend>::IntTensorPrimitive

Returns the tensor with inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

Source

fn bool_inner( tensor: Self::BoolTensorPrimitive, ) -> <Self::InnerBackend as Backend>::BoolTensorPrimitive

Returns the tensor with inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

Source

fn q_inner( tensor: Self::QuantizedTensorPrimitive, ) -> <Self::InnerBackend as Backend>::QuantizedTensorPrimitive

Returns the tensor with inner backend type.

§Arguments
  • tensor - The tensor to get the inner backend tensor for.
§Returns

The inner backend tensor.

Source

fn from_inner( tensor: <Self::InnerBackend as Backend>::FloatTensorPrimitive, ) -> Self::FloatTensorPrimitive

Converts the inner backend tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

Source

fn int_from_inner( tensor: <Self::InnerBackend as Backend>::IntTensorPrimitive, ) -> Self::IntTensorPrimitive

Converts the inner backend tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

Source

fn bool_from_inner( tensor: <Self::InnerBackend as Backend>::BoolTensorPrimitive, ) -> Self::BoolTensorPrimitive

Converts the inner backend tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

Source

fn q_from_inner( tensor: <Self::InnerBackend as Backend>::QuantizedTensorPrimitive, ) -> Self::QuantizedTensorPrimitive

Converts the inner backend tensor to the autodiff backend tensor.

§Arguments
  • tensor - The inner backend tensor to convert.
§Returns

The autodiff backend tensor.

Dyn Compatibility§

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.

Implementors§