burn_tensor/tensor/api/autodiff.rs

pub use burn_backend::tensor::BasicAutodiffOps;

use crate::{Tensor, TensorPrimitive, backend::AutodiffBackend};

impl<const D: usize, B: AutodiffBackend> Tensor<B, D> {
    /// Executes the backward pass of the tensor, returning the computed gradients.
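    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an illustrative autodiff backend alias `B`
    /// (e.g. `Autodiff<NdArray>`) and a matching `device`; these names are not
    /// provided by this module.
    ///
    /// ```ignore
    /// // Track gradients, build a scalar loss, then run the backward pass.
    /// let x = Tensor::<B, 2>::ones([2, 2], &device).require_grad();
    /// let loss = (x.clone() * x.clone()).sum();
    /// let grads = loss.backward();
    /// ```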
    pub fn backward(&self) -> B::Gradients {
        B::backward(self.primitive.clone().tensor())
    }

    /// Get the gradients of a tensor if they exist.
    ///
    /// Returns a new reference to the same gradient tensor, so it can be accessed
    /// multiple times. If you only need the gradients once, consider using
    /// [grad_remove](Tensor::grad_remove) for better performance.
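    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the same illustrative setup as in
    /// [backward](Tensor::backward) (backend alias `B` and a `device`).
    ///
    /// ```ignore
    /// let x = Tensor::<B, 2>::ones([2, 2], &device).require_grad();
    /// let grads = (x.clone() * x.clone()).sum().backward();
    /// // The gradient, if present, lives on the inner (non-autodiff) backend.
    /// let grad_x = x.grad(&grads); // Option<Tensor<B::InnerBackend, 2>>
    /// ```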
    pub fn grad(&self, grads: &B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::grad(tensor, grads)
                .map(TensorPrimitive::Float)
                .map(Tensor::new),
            TensorPrimitive::QFloat(_tensor) => B::grad(&self.primitive.clone().tensor(), grads)
                .map(TensorPrimitive::Float)
                .map(Tensor::new),
        }
    }

    /// Remove the grad tensor from the [grads](AutodiffBackend::Gradients) struct, returning the removed gradient.
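    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the same illustrative setup as in
    /// [backward](Tensor::backward).
    ///
    /// ```ignore
    /// let x = Tensor::<B, 2>::ones([2, 2], &device).require_grad();
    /// let mut grads = (x.clone() * x.clone()).sum().backward();
    /// // Takes the gradient out of `grads`; a second call would return `None`.
    /// let grad_x = x.grad_remove(&mut grads);
    /// ```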
    pub fn grad_remove(&self, grads: &mut B::Gradients) -> Option<Tensor<B::InnerBackend, D>> {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => B::grad_remove(tensor, grads)
                .map(TensorPrimitive::Float)
                .map(Tensor::new),
            TensorPrimitive::QFloat(_tensor) => {
                B::grad_remove(&self.primitive.clone().tensor(), grads)
                    .map(TensorPrimitive::Float)
                    .map(Tensor::new)
            }
        }
    }

    /// Replace the grad tensor in the [grads](AutodiffBackend::Gradients) struct with the provided
    /// gradient.
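    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the same illustrative setup as in
    /// [backward](Tensor::backward); useful for gradient manipulation such as scaling.
    ///
    /// ```ignore
    /// let x = Tensor::<B, 2>::ones([2, 2], &device).require_grad();
    /// let mut grads = (x.clone() * x.clone()).sum().backward();
    /// if let Some(grad) = x.grad_remove(&mut grads) {
    ///     // Store a scaled version of the gradient back into `grads`.
    ///     x.grad_replace(&mut grads, grad * 0.5);
    /// }
    /// ```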
    pub fn grad_replace(&self, grads: &mut B::Gradients, grad: Tensor<B::InnerBackend, D>) {
        match &self.primitive {
            TensorPrimitive::Float(tensor) => {
                B::grad_replace(tensor, grads, grad.primitive.tensor())
            }
            TensorPrimitive::QFloat(_tensor) => B::grad_replace(
                &self.primitive.clone().tensor(),
                grads,
                grad.primitive.tensor(),
            ),
        }
    }
}

impl<const D: usize, B: AutodiffBackend, K: BasicAutodiffOps<B>> Tensor<B, D, K> {
    /// Returns the inner tensor without the autodiff information.
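    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an illustrative autodiff backend alias `B` and a `device`.
    ///
    /// ```ignore
    /// let x = Tensor::<B, 2>::ones([2, 2], &device);
    /// // Drop the autodiff graph; `y` is a plain tensor on the inner backend.
    /// let y: Tensor<B::InnerBackend, 2> = x.inner();
    /// ```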
    pub fn inner(self) -> Tensor<B::InnerBackend, D, K::InnerKind> {
        Tensor::new(K::inner(self.primitive))
    }

    /// Convert a tensor from the inner backend to the autodiff backend.
    ///
    /// # Arguments
    ///
    /// * `inner` - The inner backend tensor to convert.
    ///
    /// # Returns
    ///
    /// The tensor converted to the autodiff backend.
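    ///
    /// # Example
    ///
    /// A minimal sketch, assuming an illustrative autodiff backend alias `B` and a `device`.
    ///
    /// ```ignore
    /// let inner = Tensor::<B::InnerBackend, 2>::ones([2, 2], &device);
    /// // Re-attach the tensor to the autodiff backend so it can take part in backward passes.
    /// let tracked: Tensor<B, 2> = Tensor::from_inner(inner).require_grad();
    /// ```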
    pub fn from_inner(inner: Tensor<B::InnerBackend, D, K::InnerKind>) -> Self {
        Self::new(K::from_inner(inner.primitive))
    }
}