Trait burn_tensor::backend::AutodiffBackend
source · pub trait AutodiffBackend: Backend {
type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>;
type Gradients: Send;
// Required methods
fn backward<const D: usize>(tensor: FloatTensor<Self, D>) -> Self::Gradients;
fn grad<const D: usize>(
tensor: &FloatTensor<Self, D>,
grads: &Self::Gradients,
) -> Option<FloatTensor<Self::InnerBackend, D>>;
fn grad_remove<const D: usize>(
tensor: &FloatTensor<Self, D>,
grads: &mut Self::Gradients,
) -> Option<FloatTensor<Self::InnerBackend, D>>;
fn grad_replace<const D: usize>(
tensor: &FloatTensor<Self, D>,
grads: &mut Self::Gradients,
grad: FloatTensor<Self::InnerBackend, D>,
);
fn inner<const D: usize>(
tensor: FloatTensor<Self, D>,
) -> FloatTensor<Self::InnerBackend, D>;
fn int_inner<const D: usize>(
tensor: IntTensor<Self, D>,
) -> IntTensor<Self::InnerBackend, D>;
fn bool_inner<const D: usize>(
tensor: BoolTensor<Self, D>,
) -> BoolTensor<Self::InnerBackend, D>;
fn q_inner<const D: usize>(
tensor: QuantizedTensor<Self, D>,
) -> QuantizedTensor<Self::InnerBackend, D>;
fn from_inner<const D: usize>(
tensor: FloatTensor<Self::InnerBackend, D>,
) -> FloatTensor<Self, D>;
fn int_from_inner<const D: usize>(
tensor: IntTensor<Self::InnerBackend, D>,
) -> IntTensor<Self, D>;
fn bool_from_inner<const D: usize>(
tensor: BoolTensor<Self::InnerBackend, D>,
) -> BoolTensor<Self, D>;
fn q_from_inner<const D: usize>(
tensor: QuantizedTensor<Self::InnerBackend, D>,
) -> QuantizedTensor<Self, D>;
}
Expand description
Trait that allows a backend to support autodiff.
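Required Associated Types
type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem, IntElem = Self::IntElem>
type Gradients: Send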
Required Methods
fn backward<const D: usize>(tensor: FloatTensor<Self, D>) -> Self::Gradients
fn grad<const D: usize>(
    tensor: &FloatTensor<Self, D>,
    grads: &Self::Gradients,
) -> Option<FloatTensor<Self::InnerBackend, D>>
fn grad_remove<const D: usize>(
    tensor: &FloatTensor<Self, D>,
    grads: &mut Self::Gradients,
) -> Option<FloatTensor<Self::InnerBackend, D>>
fn grad_replace<const D: usize>(
    tensor: &FloatTensor<Self, D>,
    grads: &mut Self::Gradients,
    grad: FloatTensor<Self::InnerBackend, D>,
)
Replace the gradients of a tensor with the one provided.
If no gradient existed for the provided tensor, register it.
Arguments
- `tensor` - The tensor whose gradient is replaced (or registered).
- `grads` - The gradients, updated in place.
- `grad` - The updated grad tensor.
fn inner<const D: usize>(
    tensor: FloatTensor<Self, D>,
) -> FloatTensor<Self::InnerBackend, D>
fn int_inner<const D: usize>(
    tensor: IntTensor<Self, D>,
) -> IntTensor<Self::InnerBackend, D>
fn bool_inner<const D: usize>(
    tensor: BoolTensor<Self, D>,
) -> BoolTensor<Self::InnerBackend, D>
fn q_inner<const D: usize>(
    tensor: QuantizedTensor<Self, D>,
) -> QuantizedTensor<Self::InnerBackend, D>
fn from_inner<const D: usize>(
    tensor: FloatTensor<Self::InnerBackend, D>,
) -> FloatTensor<Self, D>
fn int_from_inner<const D: usize>(
    tensor: IntTensor<Self::InnerBackend, D>,
) -> IntTensor<Self, D>
fn bool_from_inner<const D: usize>(
    tensor: BoolTensor<Self::InnerBackend, D>,
) -> BoolTensor<Self, D>
fn q_from_inner<const D: usize>(
    tensor: QuantizedTensor<Self::InnerBackend, D>,
) -> QuantizedTensor<Self, D>
Object Safety
This trait is not object safe.